浏览代码

Change dimensions of recurrent in to BxN for ONNX

/develop/reshapeonnxmemories
Ervin Teng 4 年前
当前提交
ba29b6b4
共有 2 个文件被更改,包括 5 次插入4 次删除
  1. 6
      ml-agents/mlagents/trainers/torch/model_serialization.py
  2. 3
      ml-agents/mlagents/trainers/torch/networks.py

6
ml-agents/mlagents/trainers/torch/model_serialization.py


# cause problem to barracuda import.
self.policy = policy
batch_dim = [1]
seq_len_dim = [1]
dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])]
# create input shape of NCHW
# (It's NHWC in self.policy.behavior_spec.observation_shapes)

if len(shape) == 3
]
dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
dummy_memories = torch.zeros(
batch_dim + seq_len_dim + [self.policy.export_memory_size]
)
# Assume sequence length is 1
dummy_memories = torch.zeros(batch_dim + [self.policy.export_memory_size])
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories)

3
ml-agents/mlagents/trainers/torch/networks.py


"""
Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
"""
# Barracuda prefers memories to be in format (batch, enc_size) when seq_len is 1
if memories is not None:
memories = memories.unsqueeze(1)
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
if self.act_type == ActionType.CONTINUOUS:
action_list = self.sample_action(dists)

正在加载...
取消
保存