浏览代码

fix onnx input

/develop/add-fire/ckpt-2
Ruo-Ping Dong 4 年前
当前提交
6fbd862e
共有 1 个文件被更改,包括 6 次插入7 次删除
  1. 13
      ml-agents/mlagents/trainers/torch/model_serialization.py

13
ml-agents/mlagents/trainers/torch/model_serialization.py


def __init__(self, policy):
self.policy = policy
dummy_vec_obs = [torch.zeros([1] + [self.policy.vec_obs_size])]
dummy_vis_obs = [torch.zeros([1] + self.policy.vis_obs_shape)] \
dummy_vis_obs = [torch.zeros([1] + list(self.policy.vis_obs_shape))] \
dummy_masks = [torch.ones([1] + self.policy.actor_critic.act_size)]
dummy_memories = [torch.zeros([1] + [self.policy.m_size])]
dummy_sequence_length = [torch.tensor([self.policy.sequence_length])]
dummy_masks = torch.ones([1] + self.policy.actor_critic.act_size)
dummy_memories = torch.zeros([1] + [self.policy.m_size])
"action_mask", "memories", "sequence_length"]
"action_mask", "memories"]
"action_mask": [0], "memories": [0], "action": [0],"action_probs": [0]}
"action_mask": [0], "memories": [0], "action": [0],"action_probs": [0]}
dummy_masks, dummy_memories, dummy_sequence_length)
dummy_masks, dummy_memories)
def export_policy_model(self, output_filepath: str) -> None:
"""

正在加载...
取消
保存