
Cherry picked #4491 (#4493)

[Bug fix] Export all branches for discrete control torch
/release_7_branch
GitHub, 4 years ago
Current commit: 10f3e1c7

2 files changed: 8 insertions, 5 deletions
  1. com.unity.ml-agents/CHANGELOG.md (2 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (11 changes)

com.unity.ml-agents/CHANGELOG.md (2 changes)


```diff
 and this project adheres to
 [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
 ## [1.4.0-preview] - 2020-09-16
 ### Major Changes
 #### com.unity.ml-agents (C#)

 - Fixed the sample code in the custom SideChannel example. (#4466)
 - A bug in the observation normalizer that would cause rewards to decrease
   when using `--resume` was fixed. (#4463)
+- Fixed a bug in exporting Pytorch models when using multiple discrete actions. (#4491)
 ## [1.3.0-preview] - 2020-08-12
```

ml-agents/mlagents/trainers/torch/networks.py (11 changes)


```diff
         self.is_continuous_int = torch.nn.Parameter(
             torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
         )
-        self.act_size_vector = torch.nn.Parameter(torch.Tensor(act_size))
+        self.act_size_vector = torch.nn.Parameter(
+            torch.Tensor([sum(act_size)]), requires_grad=False
+        )
         self.network_body = NetworkBody(observation_shapes, network_settings)
         if network_settings.memory is not None:
             self.encoding_size = network_settings.memory.memory_size // 2
```
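For readers outside the codebase: `act_size` here is the list of per-branch action counts for discrete control (an inference from the surrounding code, not stated in the diff). The old export advertised one size per branch; the new code exports a single constant total, sized to match the concatenated output, with `requires_grad=False` since it is metadata rather than a trainable weight. A minimal sketch of the arithmetic, using hypothetical branch sizes:

```python
# Hypothetical discrete action space: two branches with 3 and 2 choices.
act_size = [3, 2]

# Old metadata: one entry per branch, i.e. torch.Tensor(act_size) -> [3., 2.]
per_branch = list(act_size)

# New metadata: a single total, i.e. torch.Tensor([sum(act_size)]) -> [5.],
# matching the width of the concatenated per-branch log-prob tensor.
total_actions = sum(act_size)

print(per_branch, total_actions)  # [3, 2] 5
```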

```diff
         Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
         """
         dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
-        action_list = self.sample_action(dists)
-        sampled_actions = torch.stack(action_list, dim=-1)
-        action_out = sampled_actions
+        action_list = self.sample_action(dists)
+        action_out = torch.stack(action_list, dim=-1)
-        action_out = dists[0].all_log_prob()
+        action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1)
         return (
             action_out,
             self.version_number,
```
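The core of the fix can be illustrated without torch (a sketch with plain Python lists standing in for the per-branch distribution tensors; the values are hypothetical): with multiple discrete branches, `dists[0].all_log_prob()` exported only the first branch's outputs, while concatenating across all of `dists` exports every branch.

```python
# Plain-Python stand-ins for per-branch log-probability tensors
# (hypothetical values; two branches of sizes 3 and 2).
branch_log_probs = [
    [-1.0, -1.1, -1.2],  # branch 0
    [-0.5, -0.7],        # branch 1
]

# Before the fix: only the first branch reached the exported output.
buggy_out = branch_log_probs[0]

# After the fix: concatenate all branches, analogous to
# torch.cat([dist.all_log_prob() for dist in dists], dim=1).
fixed_out = [p for branch in branch_log_probs for p in branch]

print(len(buggy_out), len(fixed_out))  # 3 5
```

The concatenated width also matches the `sum(act_size)` metadata exported above, so the ONNX consumer knows how many columns to expect.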
