浏览代码

add comments

/develop/add-fire/ckpt-2
Ruo-Ping Dong 4 年前
当前提交
dbff06cd
共有 2 个文件被更改,包括 7 次插入2 次删除
  1. 1
      ml-agents/mlagents/trainers/saver/saver.py
  2. 8
      ml-agents/mlagents/trainers/torch/model_serialization.py

1
ml-agents/mlagents/trainers/saver/saver.py


def __init__(self):
pass
@abc.abstractmethod
def register(self, module: Any) -> None:
"""
Register the modules to the Saver.

8
ml-agents/mlagents/trainers/torch/model_serialization.py


dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size])
# Need to pass all posslible inputs since currently keyword arguments is not
# supported by torch.nn.export()
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories)
# Input names can only contain actual input used since in torch.nn.export
# it maps input_names only to input nodes that exist in the graph
self.input_names = []
self.dynamic_axes = {"action": {0: "batch"}, "action_probs": {0: "batch"}}
if self.policy.use_vec_obs:

"is_continuous_control",
"action_output_shape",
]
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories)
def export_policy_model(self, output_filepath: str) -> None:
"""

正在加载...
取消
保存