浏览代码

small improvements

/develop/add-fire/ckpt-2
Ruo-Ping Dong 4 年前
当前提交
b4713baa
共有 6 个文件被更改，包括 12 次插入和 7 次删除
  1. 2
      ml-agents/mlagents/trainers/policy/tf_policy.py
  2. 2
      ml-agents/mlagents/trainers/policy/torch_policy.py
  3. 2
      ml-agents/mlagents/trainers/saver/tf_saver.py
  4. 2
      ml-agents/mlagents/trainers/saver/torch_saver.py
  5. 10
      ml-agents/mlagents/trainers/torch/model_serialization.py
  6. 1
      ml-agents/mlagents/trainers/trainer/rl_trainer.py

2
ml-agents/mlagents/trainers/policy/tf_policy.py


step = self.sess.run(self.global_step)
return step
def _set_step(self, step: int) -> int:
def set_step(self, step: int) -> int:
"""
Sets current model step to step without creating additional ops.
:param step: Step to set the current model step to.

2
ml-agents/mlagents/trainers/policy/torch_policy.py


"""
return self.global_step.current_step
def _set_step(self, step: int) -> int:
def set_step(self, step: int) -> int:
"""
Sets current model step to step without creating additional ops.
:param step: Step to set the current model step to.

2
ml-agents/mlagents/trainers/saver/tf_saver.py


)
self._check_model_version(__version__)
if reset_global_steps:
self.policy._set_step(0)
self.policy.set_step(0)
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path

2
ml-agents/mlagents/trainers/saver/torch_saver.py


for name, state_dict in saved_state_dict.items():
self.modules[name].load_state_dict(state_dict)
if reset_global_steps:
self.policy._set_step(0)
self.policy.set_step(0)
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path

10
ml-agents/mlagents/trainers/torch/model_serialization.py


class ModelSerializer:
def __init__(self, policy):
self.policy = policy
dummy_vec_obs = [torch.zeros([1] + [self.policy.vec_obs_size])]
# dimension for batch (and sequence_length if use recurrent)
dummy_dim = [1, 1] if self.policy.use_recurrent else [1]
dummy_vec_obs = [torch.zeros(dummy_dim + [self.policy.vec_obs_size])]
[torch.zeros([1] + list(self.policy.vis_obs_shape))]
[torch.zeros(dummy_dim + list(self.policy.vis_obs_shape))]
dummy_memories = torch.zeros([1] + [self.policy.m_size])
dummy_memories = torch.zeros(dummy_dim + [self.policy.m_size])
self.input_names = [
"vector_observation",

}
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories)
@staticmethod
def export_policy_model(self, output_filepath: str) -> None:
"""
Exports a Torch model for a Policy to .onnx format for Unity embedding.

1
ml-agents/mlagents/trainers/trainer/rl_trainer.py


"""
pass
@staticmethod
def create_saver(self, policy: Policy) -> BaseSaver:
if self.framework == "torch":
saver = TorchSaver( # type: ignore

正在加载...
取消
保存