浏览代码

fix export input name

/develop/add-fire/ckpt-2
Ruo-Ping Dong 4 年前
当前提交
6a3b38e1
共有 1 个文件被更改,包括 6 次插入6 次删除
  1. 12
      ml-agents/mlagents/trainers/torch/model_serialization.py

12
ml-agents/mlagents/trainers/torch/model_serialization.py


dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size])
self.input_names = []
self.dynamic_axes = {"action": {0: 'batch'}, "action_probs": {0: 'batch'}}
self.dynamic_axes = {"action": {0: "batch"}, "action_probs": {0: "batch"}}
self.dynamic_axes.update({"vector_observation": {0: 'batch'}})
self.dynamic_axes.update({"vector_observation": {0: "batch"}})
self.dynamic_axes.update({"visual_observation": {0: 'batch'}})
self.dynamic_axes.update({"visual_observation": {0: "batch"}})
self.input_names.append("action_mask")
self.dynamic_axes.update({"action_mask": {0: 'batch'}})
self.input_names.append("action_masks")
self.dynamic_axes.update({"action_masks": {0: "batch"}})
self.dynamic_axes.update({"memories": {0: 'batch'}})
self.dynamic_axes.update({"memories": {0: "batch"}})
self.output_names = [
"action",

正在加载...
取消
保存