
Fix export

/develop/add-fire · GitHub · 4 years ago
Current commit: 347bde3d
4 changed files with 34 additions and 42 deletions
  1. ml-agents/mlagents/trainers/tests/torch/test_networks.py (11 changed lines)
  2. ml-agents/mlagents/trainers/torch/layers.py (4 changed lines)
  3. ml-agents/mlagents/trainers/torch/model_serialization.py (41 changed lines)
  4. ml-agents/mlagents/trainers/torch/networks.py (20 changed lines)

ml-agents/mlagents/trainers/tests/torch/test_networks.py (11 changed lines)


     assert act.shape == (1, 1)
     # Test forward
-    actions, probs, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
+    actions, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
-    assert act.shape == (
-        act_size[0],
-        1,
-    )  # This is different from above for ONNX export
+    # This is different from above for ONNX export
+    assert act.shape == (act_size[0], 1)
     assert act.shape == (1, 1)
     assert act.shape == tuple(act_size)
     # TODO: Once export works properly, fix the shapes here.
     assert mem_size == 0
     assert is_cont == int(action_type == ActionType.CONTINUOUS)
     assert act_size_vec == torch.tensor(act_size)

ml-agents/mlagents/trainers/torch/layers.py (4 changed lines)


     def forward(
         self, input_tensor: torch.Tensor, memories: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        h0, c0 = torch.split(memories, self.hidden_size, dim=-1)
+        # We don't use torch.split here since it is not supported by Barracuda
+        h0 = memories[:, :, : self.hidden_size]
+        c0 = memories[:, :, self.hidden_size :]
         hidden = (h0, c0)
         lstm_out, hidden_out = self.lstm(input_tensor, hidden)
         output_mem = torch.cat(hidden_out, dim=-1)
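
The slicing added above computes exactly what the removed torch.split call did, which is easy to check in isolation. A minimal sketch, assuming a hypothetical hidden_size and memory layout matching the LSTM wrapper in this hunk:

import torch

hidden_size = 4  # hypothetical; stands in for self.hidden_size
# memories packs (h0, c0) along the last axis, as in the hunk above
memories = torch.randn(1, 1, 2 * hidden_size)

# torch.split, which Barracuda's ONNX import does not support:
h_split, c_split = torch.split(memories, hidden_size, dim=-1)

# Plain slicing, which exports to ops Barracuda can consume:
h_slice = memories[:, :, :hidden_size]
c_slice = memories[:, :, hidden_size:]

assert torch.equal(h_split, h_slice) and torch.equal(c_split, c_slice)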

ml-agents/mlagents/trainers/torch/model_serialization.py (41 changed lines)


 class ModelSerializer:
     def __init__(self, policy):
+        # ONNX only supports inputs in NCHW (channel-first) format.
+        # Barracuda also expects to get data in NCHW.
+        # Any multi-dimensional input should follow that format, since
+        # anything else will cause problems for the Barracuda import.
+        seq_len_dim = [1]
+        # Create an input shape of NCHW
+        # (it is NHWC in self.policy.behavior_spec.observation_shapes)
-        torch.zeros(batch_dim + list(shape))
+        torch.zeros(batch_dim + [shape[2], shape[0], shape[1]])
-        dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.export_memory_size])
+        dummy_memories = torch.zeros(
+            batch_dim + seq_len_dim + [self.policy.export_memory_size]
+        )
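
To make the channel reordering concrete: observation_shapes stores visual shapes as (height, width, channels), so indexing [shape[2], shape[0], shape[1]] yields the channel-first layout that ONNX and Barracuda expect. A standalone sketch, with a hypothetical 84x84 RGB observation:

import torch

batch_dim = [1]
shape = (84, 84, 3)  # hypothetical NHWC observation shape: (H, W, C)

nhwc_dummy = torch.zeros(batch_dim + list(shape))                     # (1, 84, 84, 3)
nchw_dummy = torch.zeros(batch_dim + [shape[2], shape[0], shape[1]])  # (1, 3, 84, 84)

assert nchw_dummy.shape == (1, 3, 84, 84)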
-        # Need to pass all possible inputs since currently keyword arguments are not
-        # supported by torch.onnx.export()
+        # Input names can only contain the inputs actually used, since
+        # torch.onnx.export maps input_names only to input nodes that
+        # exist in the graph
+        self.input_names = []
+        self.dynamic_axes = {"action": {0: "batch"}, "action_probs": {0: "batch"}}
+        if self.policy.use_vec_obs:
+            self.input_names.append("vector_observation")
+            self.dynamic_axes.update({"vector_observation": {0: "batch"}})
+        for i in range(self.policy.vis_obs_size):
+            self.input_names.append(f"visual_observation_{i}")
+            self.dynamic_axes.update({f"visual_observation_{i}": {0: "batch"}})
+        if not self.policy.use_continuous_act:
+            self.input_names.append("action_masks")
+            self.dynamic_axes.update({"action_masks": {0: "batch"}})
+        if self.policy.use_recurrent:
+            self.input_names.append("memories")
+            self.dynamic_axes.update({"memories": {0: "batch"}})
-        self.input_names = (
-            ["vector_observation"]
-            + [f"visual_observation_{i}" for i in range(self.policy.vis_obs_size)]
-            + ["action_masks", "memories"]
-        )
-            "action_probs",
-        self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}
-        self.dynamic_axes.update({"action": {0: "batch"}})
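
The new comment is the crux of the fix: torch.onnx.export assigns input_names positionally to the inputs that survive tracing, so naming an input the graph never uses would shift later names onto the wrong nodes. A toy illustration, assuming a hypothetical one-input module and output file name:

import torch

class OneInput(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

# Exactly one input exists in the traced graph, so exactly one name is
# given; adding a second, unused name (e.g. "action_masks") would be
# silently mismapped.
torch.onnx.export(
    OneInput(),
    (torch.zeros(1, 4),),
    "one_input.onnx",
    input_names=["vector_observation"],
    dynamic_axes={"vector_observation": {0: "batch"}},
)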
     def export_policy_model(self, output_filepath: str) -> None:
         """

         torch.onnx.export(
             self.policy.actor_critic,
             self.dummy_input,
             onnx_output_path,
             verbose=False,
             opset_version=SerializationSettings.onnx_opset,
             input_names=self.input_names,
             output_names=self.output_names,

ml-agents/mlagents/trainers/torch/networks.py (20 changed lines)


     for idx, encoder in enumerate(self.visual_encoders):
         vis_input = vis_inputs[idx]
-        vis_input = vis_input.permute([0, 3, 1, 2])
+        if not torch.onnx.is_in_onnx_export():
+            vis_input = vis_input.permute([0, 3, 1, 2])
         hidden = encoder(vis_input)
         encodes.append(hidden)
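
torch.onnx.is_in_onnx_export() returns True only while torch.onnx.export is tracing the module, so the permute above still runs for normal PyTorch inference on NHWC tensors but drops out of the exported graph, which is fed NCHW data directly. A minimal sketch of the pattern, with a hypothetical toy encoder:

import torch

class NHWCEncoder(torch.nn.Module):
    def forward(self, vis_input: torch.Tensor) -> torch.Tensor:
        if not torch.onnx.is_in_onnx_export():
            # Python-side callers pass NHWC; reorder to NCHW here.
            vis_input = vis_input.permute([0, 3, 1, 2])
        # During export this branch is skipped: the ONNX graph
        # assumes its input is already NCHW.
        return vis_input.mean(dim=[2, 3])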

         vis_inputs: List[torch.Tensor],
         masks: Optional[torch.Tensor] = None,
         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
-    ) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
+    ) -> Tuple[torch.Tensor, int, int, int, int]:
         """
         Forward pass of the Actor for inference. This is required for export to ONNX, and
         the inputs and outputs of this method should not be changed without a respective change

         vis_inputs: List[torch.Tensor],
         masks: Optional[torch.Tensor] = None,
         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
-    ) -> Tuple[torch.Tensor, torch.Tensor, int, int, int, int]:
+    ) -> Tuple[torch.Tensor, int, int, int, int]:
-        dists, _ = self.get_dists(
-            vec_inputs, vis_inputs, masks, memories, sequence_length
-        )
+        dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
-        log_probs = dists[0].log_prob(sampled_actions)
+        action_out = sampled_actions
-        log_probs = dists[0].all_log_prob()
+        action_out = dists[0].all_log_prob()
         return (
-            sampled_actions,
-            log_probs,
+            action_out,
             self.version_number,
             torch.Tensor([self.network_body.memory_size]),
             self.is_continuous_int,
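
Since the extra return values are constants baked into the exported graph, a consumer can read them back as ordinary outputs. A hedged sketch using onnxruntime, assuming a vector-observation-only policy; the model path, input name, and observation size 8 are illustrative assumptions:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
outputs = session.run(
    None, {"vector_observation": np.zeros((1, 8), dtype=np.float32)}
)
# With the ordering returned by forward() above:
# action, version_number, memory_size, is_continuous_control, ...
action, version_number, memory_size, is_continuous = outputs[:4]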
