浏览代码

Save checkpoint files as .nn files in checkpoint directory

/develop/model-store
PSankalp Patro 5 年前
当前提交
45c4ea36
共有 4 个文件被更改,包括 20 次插入6 次删除
  1. 11
      ml-agents/mlagents/model_serialization.py
  2. 2
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  3. 11
      ml-agents/mlagents/trainers/trainer/trainer.py
  4. 2
      ml-agents/mlagents/trainers/trainer_controller.py

11
ml-agents/mlagents/model_serialization.py


class SerializationSettings(NamedTuple):
model_path: str
brain_name: str
checkpoint_path: str = ""
convert_to_barracuda: bool = True
convert_to_onnx: bool = True
onnx_opset: int = 9

settings: SerializationSettings, graph: tf.Graph, sess: tf.Session
settings: SerializationSettings, graph: tf.Graph, sess: tf.Session, is_checkpoint: bool = False
) -> None:
"""
Exports latest saved model to .nn format for Unity embedding.

# Convert to barracuda
if settings.convert_to_barracuda:
tf2bc.convert(frozen_graph_def_path, settings.model_path + ".nn")
logger.info(f"Exported {settings.model_path}.nn file")
if is_checkpoint :
tf2bc.convert(frozen_graph_def_path, os.path.join(settings.model_path, f"{settings.checkpoint_path}.nn"))
logger.info(f"Exported {settings.checkpoint_path}.nn file")
else:
tf2bc.convert(frozen_graph_def_path, settings.model_path + ".nn")
logger.info(f"Exported {settings.model_path}.nn file")
# Save to onnx too (if we were able to import it)
if ONNX_EXPORT_ENABLED:

2
ml-agents/mlagents/trainers/trainer/rl_trainer.py


if step_after_process >= self._next_save_step and self.get_step != 0:
logger.info(f"Checkpointing model for {self.brain_name}.")
self.save_model(self.brain_name)
logger.info(f"Exporting a checkpoint model for {self.brain_name}.")
self.export_model(self.brain_name, is_checkpoint=True)
def advance(self) -> None:
"""

11
ml-agents/mlagents/trainers/trainer/trainer.py


"""
self.get_policy(name_behavior_id).save_model(self.get_step)
def export_model(self, name_behavior_id: str) -> None:
def export_model(self, name_behavior_id: str, is_checkpoint=False) -> None:
settings = SerializationSettings(policy.model_path, policy.brain.brain_name)
export_policy_model(settings, policy.graph, policy.sess)
if is_checkpoint:
checkpoint_path = f"{name_behavior_id}-{self.get_step}"
settings = SerializationSettings(policy.model_path, policy.brain.brain_name, checkpoint_path)
else:
settings = SerializationSettings(policy.model_path, policy.brain.brain_name)
export_policy_model(settings, policy.graph, policy.sess, is_checkpoint)
@abc.abstractmethod
def end_episode(self):

2
ml-agents/mlagents/trainers/trainer_controller.py


"Learning was interrupted. Please wait while the graph is generated."
)
self._save_model()
self._export_graph()
def _export_graph(self):
"""

for name_behavior_id in self.brain_name_to_identifier[brain_name]:
self.trainers[brain_name].export_model(name_behavior_id)
self.logger.info("Exported Model")
@staticmethod
def _create_output_path(output_path):

正在加载...
取消
保存