浏览代码

Always export one Action tensor (#4388)

/develop/add-fire/export-discrete
GitHub 4 年前
当前提交
a855cf09
共有 2 个文件被更改,包括 7 次插入16 次删除
  1. 12
      ml-agents/mlagents/trainers/torch/model_serialization.py
  2. 11
      ml-agents/mlagents/trainers/torch/networks.py

12
ml-agents/mlagents/trainers/torch/model_serialization.py


+ ["action_masks", "memories"]
)
if self.policy.use_continuous_act:
action_name = "action"
action_prob_name = "action_probs"
else:
action_name = "action_unused"
action_prob_name = "action"
action_name,
action_prob_name,
"action",
"version_number",
"memory_size",
"is_continuous_control",

self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}
self.dynamic_axes.update({"action": {0: "batch"}, "action_probs": {0: "batch"}})
self.dynamic_axes.update({"action": {0: "batch"}})
def export_policy_model(self, output_filepath: str) -> None:
"""

11
ml-agents/mlagents/trainers/torch/networks.py


vis_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
) -> Tuple[torch.Tensor, int, int, int, int]:
"""
Forward pass of the Actor for inference. This is required for export to ONNX, and
the inputs and outputs of this method should not be changed without a respective change

vis_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
) -> Tuple[torch.Tensor, int, int, int, int]:
"""
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
"""

if self.act_type == ActionType.CONTINUOUS:
log_probs = dists[0].log_prob(sampled_actions)
action_out = sampled_actions
log_probs = dists[0].all_log_prob()
action_out = dists[0].all_log_prob()
sampled_actions,
log_probs,
action_out,
self.version_number,
torch.Tensor([self.network_body.memory_size]),
self.is_continuous_int,

正在加载...
取消
保存