浏览代码

make visual input channel first for export

/develop/add-fire/export-discrete
Ruo-Ping Dong 4 年前
当前提交
b8dbbc17
共有 2 个文件被更改,包括 3 次插入2 次删除
  1. 2
      ml-agents/mlagents/trainers/torch/model_serialization.py
  2. 3
      ml-agents/mlagents/trainers/torch/networks.py

2
ml-agents/mlagents/trainers/torch/model_serialization.py


seq_len_dim = [1]
dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])]
dummy_vis_obs = [
torch.zeros(batch_dim + list(shape))
torch.zeros(batch_dim + [shape[2], shape[0], shape[1]])
for shape in self.policy.behavior_spec.observation_shapes
if len(shape) == 3
]

3
ml-agents/mlagents/trainers/torch/networks.py


for idx, encoder in enumerate(self.visual_encoders):
vis_input = vis_inputs[idx]
vis_input = vis_input.permute([0, 3, 1, 2])
if not torch.onnx.is_in_onnx_export():
vis_input = vis_input.permute([0, 3, 1, 2])
hidden = encoder(vis_input)
encodes.append(hidden)

正在加载...
取消
保存