浏览代码

fix export input_name

/develop/add-fire/export-discrete
Ruo-Ping Dong 4 年前
当前提交
cf1e7ca0
共有 2 个文件被更改,包括 11 次插入25 次删除
  1. 30
      ml-agents/mlagents/trainers/torch/model_serialization.py
  2. 6
      ml-agents/mlagents/trainers/torch/networks.py

30
ml-agents/mlagents/trainers/torch/model_serialization.py


def __init__(self, policy):
self.policy = policy
batch_dim = [1]
seq_len_dim = [1]
dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])]
dummy_vis_obs = [
torch.zeros(batch_dim + list(shape))

dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size])
dummy_memories = torch.zeros(batch_dim + seq_len_dim + [self.policy.m_size])
# Need to pass all possible inputs since currently keyword arguments is not
# supported by torch.nn.export()
# Input names can only contain actual input used since in torch.nn.export
# it maps input_names only to input nodes that exist in the graph
self.input_names = []
self.dynamic_axes = {"action": {0: "batch"}, "action_probs": {0: "batch"}}
if self.policy.use_vec_obs:
self.input_names.append("vector_observation")
self.dynamic_axes.update({"vector_observation": {0: "batch"}})
for i in range(self.policy.vis_obs_size):
self.input_names.append(f"visual_observation_{i}")
self.dynamic_axes.update({f"visual_observation_{i}": {0: "batch"}})
if not self.policy.use_continuous_act:
self.input_names.append("action_masks")
self.dynamic_axes.update({"action_masks": {0: "batch"}})
if self.policy.use_recurrent:
self.input_names.append("memories")
self.dynamic_axes.update({"memories": {0: "batch"}})
self.input_names = (
["vector_observation"]
+ [f"visual_observation_{i}" for i in range(self.policy.vis_obs_size)]
+ ["action_masks", "memories"]
)
self.output_names = [
"action",

"is_continuous_control",
"action_output_shape",
]
self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}
self.dynamic_axes.update({"action": {0: "batch"}, "action_probs": {0: "batch"}})
def export_policy_model(self, output_filepath: str) -> None:
"""

self.policy.actor_critic,
self.dummy_input,
onnx_output_path,
verbose=False,
opset_version=SerializationSettings.onnx_opset,
input_names=self.input_names,
output_names=self.output_names,

6
ml-agents/mlagents/trainers/torch/networks.py


vis_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
"""
Forward pass of the Actor for inference. This is required for export to ONNX, and

vis_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
dists, _ = self.get_dists(
vec_inputs, vis_inputs, masks, memories, sequence_length
)
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
action_list = self.sample_action(dists)
sampled_actions = torch.stack(action_list, dim=-1)
if self.act_type == ActionType.CONTINUOUS:

正在加载...
取消
保存