浏览代码

[Bug fix] Export all branches for discrete control torch (#4491)

* Export all branches for discrete control torch

* [skip ci] Changelog edits

* Update ml-agents/mlagents/trainers/torch/networks.py

Co-authored-by: Ruo-Ping (Rachel) Dong <ruoping.dong@unity3d.com>

* Update ml-agents/mlagents/trainers/torch/networks.py

* Fix formatting

Co-authored-by: Ruo-Ping (Rachel) Dong <ruoping.dong@unity3d.com>
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
069f10b9
共有 2 个文件被更改,包括 7 次插入5 次删除
  1. 1
      com.unity.ml-agents/CHANGELOG.md
  2. 11
      ml-agents/mlagents/trainers/torch/networks.py

1
com.unity.ml-agents/CHANGELOG.md


### Bug Fixes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Fixed a bug in exporting Pytorch models when using multiple discrete actions. (#4491)
## [1.4.0-preview] - 2020-09-16

11
ml-agents/mlagents/trainers/torch/networks.py


self.is_continuous_int = torch.nn.Parameter(
torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
)
self.act_size_vector = torch.nn.Parameter(torch.Tensor(act_size))
self.act_size_vector = torch.nn.Parameter(
torch.Tensor([sum(act_size)]), requires_grad=False
)
self.network_body = NetworkBody(observation_shapes, network_settings)
if network_settings.memory is not None:
self.encoding_size = network_settings.memory.memory_size // 2

Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
"""
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
action_list = self.sample_action(dists)
sampled_actions = torch.stack(action_list, dim=-1)
action_out = sampled_actions
action_list = self.sample_action(dists)
action_out = torch.stack(action_list, dim=-1)
action_out = dists[0].all_log_prob()
action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1)
return (
action_out,
self.version_number,

正在加载...
取消
保存