浏览代码

Fix of the test for multi visual input

/develop/add-fire/export-discrete
vincentpierre 4 年前
当前提交
349cee77
共有 2 个文件被更改,包括 9 次插入14 次删除
  1. 5
      ml-agents/mlagents/trainers/policy/policy.py
  2. 18
      ml-agents/mlagents/trainers/torch/model_serialization.py

5
ml-agents/mlagents/trainers/policy/policy.py


self.vis_obs_size = sum(
1 for shape in behavior_spec.observation_shapes if len(shape) == 3
)
self.vis_obs_shape = (
[shape for shape in behavior_spec.observation_shapes if len(shape) == 3][0]
if self.vis_obs_size > 0
else None
)
self.use_continuous_act = behavior_spec.is_action_continuous()
self.num_branches = self.behavior_spec.action_size
self.previous_action_dict: Dict[str, np.array] = {}

18
ml-agents/mlagents/trainers/torch/model_serialization.py


self.policy = policy
batch_dim = [1]
dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])]
dummy_vis_obs = (
[torch.zeros(batch_dim + list(self.policy.vis_obs_shape))]
if self.policy.vis_obs_size > 0
else []
)
dummy_vis_obs = [
torch.zeros(batch_dim + list(shape))
for shape in self.policy.behavior_spec.observation_shapes
if len(shape) == 3
]
# Need to pass all posslible inputs since currently keyword arguments is not
# Need to pass all possible inputs since currently keyword arguments is not
# supported by torch.nn.export()
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories)

if self.policy.use_vec_obs:
self.input_names.append("vector_observation")
self.dynamic_axes.update({"vector_observation": {0: "batch"}})
if self.policy.use_vis_obs:
self.input_names.append("visual_observation")
self.dynamic_axes.update({"visual_observation": {0: "batch"}})
for i in range(self.policy.vis_obs_size):
self.input_names.append(f"visual_observation_{i}")
self.dynamic_axes.update({f"visual_observation_{i}": {0: "batch"}})
if not self.policy.use_continuous_act:
self.input_names.append("action_masks")
self.dynamic_axes.update({"action_masks": {0: "batch"}})

正在加载...
取消
保存