
move checkpoint_path logic to saver

/develop/add-fire/ckpt-2
Ruo-Ping Dong, 4 years ago
Current commit
4e87b422
6 files changed, 14 insertions(+), 14 deletions(-)
  1. ml-agents/mlagents/trainers/policy/torch_policy.py (5 changes)
  2. ml-agents/mlagents/trainers/saver/tf_saver.py (7 changes)
  3. ml-agents/mlagents/trainers/saver/torch_saver.py (4 changes)
  4. ml-agents/mlagents/trainers/torch/model_serialization.py (5 changes)
  5. ml-agents/mlagents/trainers/torch/networks.py (4 changes)
  6. ml-agents/mlagents/trainers/trainer/rl_trainer.py (3 changes)
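In short, checkpoint path construction moves out of the trainer and into the saver: save_checkpoint now takes brain_name and step, builds the path itself, and returns it. A minimal sketch of the new call shape; the Saver class, run directory, and brain name below are illustrative stand-ins, not the actual TFSaver/TorchSaver:

    import os

    class Saver:
        """Illustrative stand-in for TFSaver/TorchSaver, showing only the new call shape."""

        def __init__(self, model_path):
            self.model_path = model_path

        def save_checkpoint(self, brain_name, step):
            # The saver now owns checkpoint path construction (previously done by the trainer).
            checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
            # ... write the actual checkpoint files under this prefix ...
            return checkpoint_path

    # Old call shape (trainer built the path and passed it in):
    #     checkpoint_path = os.path.join(saver.model_path, f"{brain_name}-{step}")
    #     saver.save_checkpoint(checkpoint_path, brain_name)
    # New call shape (saver builds and returns it):
    saver = Saver("results/example_run")
    print(saver.save_checkpoint("ExampleBrain", 5000))  # results/example_run/ExampleBrain-5000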

ml-agents/mlagents/trainers/policy/torch_policy.py (5 changes)


         Gets current model step.
         :return: current model step.
         """
-        step = self.global_step.get_step()
-        return step
+        return self.global_step.current_step

     def _set_step(self, step: int) -> int:
         """

         """
-        self.global_step.set_step(step)
+        self.global_step.current_step = step

     def increment_step(self, n_steps):
         """

ml-agents/mlagents/trainers/saver/tf_saver.py (7 changes)


     def register(self, module_dict):
         pass

-    def save_checkpoint(self, checkpoint_path: str, brain_name: str) -> None:
+    def save_checkpoint(self, brain_name: str, step: int) -> None:
         """
         Checkpoints the policy on disk.

-        print('save checkpoint_path:', checkpoint_path)
+        checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
         # Save the TF checkpoint and graph definition
         with self.graph.as_default():
             if self.saver:

                 )
         # also save the policy so we have optimized model files for each checkpoint
         self.export(checkpoint_path, brain_name)
+        return checkpoint_path

     def export(self, output_filepath: str, brain_name: str) -> None:
         """

         :param output_filepath: path (without suffix) for the model file(s)
         :param brain_name: Brain name of brain to be trained.
         """
-        print('export output_filepath:', output_filepath)
         export_policy_model(output_filepath, brain_name, self.graph, self.sess)

     def maybe_load(self):

         self.policy._initialize_graph()

     def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
-        print('load model_path:', model_path)
         with self.graph.as_default():
             logger.info(f"Loading model from {model_path}.")
             ckpt = tf.train.get_checkpoint_state(model_path)
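As background for the _load_graph hunk: tf.train.get_checkpoint_state is what maps a model directory back to the newest <brain_name>-<step> prefix that save_checkpoint wrote. A minimal sketch in TF 1.x style, matching the graph/session usage above; the function name and arguments here are illustrative, not the saver's real API:

    import tensorflow as tf  # TF 1.x API, as used by tf_saver.py

    def load_latest_checkpoint(graph, sess, model_path):
        # get_checkpoint_state reads the 'checkpoint' index file under model_path and
        # reports the newest checkpoint prefix (e.g. model_path/<brain_name>-<step>).
        ckpt = tf.train.get_checkpoint_state(model_path)
        if ckpt is None:
            raise FileNotFoundError(f"No checkpoint found in {model_path}")
        with graph.as_default():
            # Restore all variables of the policy graph from that prefix.
            tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)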

ml-agents/mlagents/trainers/saver/torch_saver.py (4 changes)


     def register(self, module):
         self.modules.update(module.get_modules())

-    def save_checkpoint(self, checkpoint_path: str, brain_name: str) -> None:
+    def save_checkpoint(self, brain_name: str, step: int) -> str:
         """
         Checkpoints the policy on disk.

         if not os.path.exists(self.model_path):
             os.makedirs(self.model_path)
+        checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
+        return checkpoint_path

     def maybe_load(self):
         # If there is an initialize path, load from that. Else, load from the set model path.
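On the PyTorch side, register collects the modules to be checkpointed, and save_checkpoint now owns both directory creation and path naming. The sketch below shows one plausible shape of that flow; the MiniTorchSaver name, the per-name register signature, and the single-file .pt format are assumptions for illustration, not the actual torch_saver implementation:

    import os
    from typing import Dict
    import torch
    import torch.nn as nn

    class MiniTorchSaver:
        """Illustrative sketch: collect registered modules and checkpoint their state dicts."""

        def __init__(self, model_path):
            self.model_path = model_path
            self.modules: Dict[str, nn.Module] = {}

        def register(self, name, module):
            # The real saver pulls a dict from module.get_modules(); here we register one directly.
            self.modules[name] = module

        def save_checkpoint(self, brain_name, step):
            if not os.path.exists(self.model_path):
                os.makedirs(self.model_path)
            checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
            # Assumed save format: one file holding every registered module's state_dict.
            state_dict = {name: mod.state_dict() for name, mod in self.modules.items()}
            torch.save(state_dict, f"{checkpoint_path}.pt")
            return checkpoint_path

    saver = MiniTorchSaver("results/example_run")
    saver.register("policy", nn.Linear(4, 2))
    print(saver.save_checkpoint("ExampleBrain", 100))  # results/example_run/ExampleBrain-100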

ml-agents/mlagents/trainers/torch/model_serialization.py (5 changes)


         dummy_masks = [torch.ones([1] + self.policy.actor_critic.act_size)]
         dummy_memories = [torch.zeros([1] + [self.policy.m_size])]
         dummy_sequence_length = [torch.tensor([self.policy.sequence_length])]
         self.input_names = ["vector_observation", "visual_observation", \
             "action_mask", "memories", "sequence_length"]
         self.output_names = ["action", "action_probs", "version_number", \

         """
         if not os.path.exists(output_filepath):
             os.makedirs(output_filepath)
         onnx_output_path = f"{output_filepath}.onnx"
         logger.info(f"Converting to {onnx_output_path}")

             output_names=self.output_names,
             dynamic_axes=self.dynamic_axes,
         )
         logger.info(f"Exported {onnx_output_path}.onnx")

ml-agents/mlagents/trainers/torch/networks.py (4 changes)


         self._global_step = nn.Parameter(torch.Tensor([0]), requires_grad=False)

     @property
-    def step(self):
+    def current_step(self):

-    @step.setter
+    @current_step.setter
     def set_step(self, value):
         self._global_step.data = value
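The GlobalSteps change above exposes the counter as a current_step property backed by a non-trainable nn.Parameter, so the step count is saved and restored along with the module's state_dict(). A self-contained sketch consistent with those lines; the method bodies, the increment helper, and the conventional setter naming are assumptions, since the diff only shows signatures:

    import torch
    import torch.nn as nn

    class GlobalSteps(nn.Module):
        """Step counter stored as a non-trainable parameter so it travels with state_dict()."""

        def __init__(self):
            super().__init__()
            self._global_step = nn.Parameter(torch.Tensor([0]), requires_grad=False)

        @property
        def current_step(self):
            return int(self._global_step.item())

        @current_step.setter
        def current_step(self, value):
            self._global_step.data = torch.Tensor([value])

        def increment(self, n):
            self._global_step.data += n

    steps = GlobalSteps()
    steps.current_step = 1000
    steps.increment(50)
    print(steps.current_step)   # 1050
    print(steps.state_dict())   # OrderedDict with '_global_step' -> tensor([1050.])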

ml-agents/mlagents/trainers/trainer/rl_trainer.py (3 changes)


             logger.warning(
                 "Trainer has multiple policies, but default behavior only saves the first."
             )
-        checkpoint_path = os.path.join(self.saver.model_path, f"{self.brain_name}-{self.step}")
-        self.saver.save_checkpoint(checkpoint_path, self.brain_name)
+        checkpoint_path = self.saver.save_checkpoint(self.brain_name, self.step)
         new_checkpoint = NNCheckpoint(
             int(self.step),
             f"{checkpoint_path}.nn",
