浏览代码

Revert to brain_name temporarily

/develop/add-fire/clean2
Ervin Teng 4 年前
当前提交
52a686d5
共有 1 个文件被更改,包括 4 次插入4 次删除
  1. 8
      ml-agents/mlagents/trainers/saver/torch_saver.py

8
ml-agents/mlagents/trainers/saver/torch_saver.py


self.policy = module
self.exporter = ModelSerializer(self.policy)
def save_checkpoint(self, behavior_name: str, step: int) -> str:
def save_checkpoint(self, brain_name: str, step: int) -> str:
checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
self.export(checkpoint_path, behavior_name)
self.export(checkpoint_path, brain_name)
def export(self, output_filepath: str, behavior_name: str) -> None:
def export(self, output_filepath: str, brain_name: str) -> None:
if self.exporter is not None:
self.exporter.export_policy_model(output_filepath)

正在加载...
取消
保存