
[bug-fix] Delete .pt checkpoints past keep-checkpoints (#5271)

* Manage non-ONNX files with checkpoint manager too

* Update tests

* Update training status version

* Change ticking of status file version
/check-for-ModelOverriders
GitHub, 4 years ago
Current commit 28eb43dd
7 files changed, with 38 insertions and 16 deletions
  1. ml-agents/mlagents/trainers/model_saver/model_saver.py (7 changes)
  2. ml-agents/mlagents/trainers/model_saver/torch_model_saver.py (8 changes)
  3. ml-agents/mlagents/trainers/policy/checkpoint_manager.py (15 changes)
  4. ml-agents/mlagents/trainers/tests/test_rl_trainer.py (11 changes)
  5. ml-agents/mlagents/trainers/tests/test_training_status.py (3 changes)
  6. ml-agents/mlagents/trainers/trainer/rl_trainer.py (8 changes)
  7. ml-agents/mlagents/trainers/training_status.py (2 changes)

ml-agents/mlagents/trainers/model_saver/model_saver.py (7 changes)


 # # Unity ML-Agents Toolkit
 import abc
-from typing import Any
+from typing import Any, Tuple, List

 class BaseModelSaver(abc.ABC):

         pass

     @abc.abstractmethod
-    def save_checkpoint(self, behavior_name: str, step: int) -> str:
+    def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]:
+        :return: A Tuple of the path to the exported file, as well as a List of any
+        auxillary files that were returned. For instance, an exported file would be Model.onnx,
+        and the auxillary files would be [Model.pt] for PyTorch
         """
         pass
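The abstract contract now requires every saver to report auxiliary files alongside the exported model. Below is a minimal, self-contained sketch of that return shape; DummyModelSaver and its file layout are illustrative only and are not part of this change.

    import os
    from typing import List, Tuple


    class DummyModelSaver:
        """Illustrative stand-in for a BaseModelSaver implementation (not the real API)."""

        def __init__(self, model_path: str):
            self.model_path = model_path

        def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]:
            os.makedirs(self.model_path, exist_ok=True)
            base = os.path.join(self.model_path, f"{behavior_name}-{step}")
            for ext in (".onnx", ".pt"):
                open(base + ext, "wb").close()  # stand-in for real serialization
            # First element: the exported model; second: auxiliary files that the
            # checkpoint manager should delete together with it when pruning.
            return f"{base}.onnx", [f"{base}.pt"]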

ml-agents/mlagents/trainers/model_saver/torch_model_saver.py (8 changes)


 import os
 import shutil
 from mlagents.torch_utils import torch
-from typing import Dict, Union, Optional, cast
+from typing import Dict, Union, Optional, cast, Tuple, List
 from mlagents_envs.exception import UnityPolicyException
 from mlagents_envs.logging_util import get_logger
 from mlagents.trainers.model_saver.model_saver import BaseModelSaver

             self.policy = module
             self.exporter = ModelSerializer(self.policy)

-    def save_checkpoint(self, behavior_name: str, step: int) -> str:
+    def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]]:
         if not os.path.exists(self.model_path):
             os.makedirs(self.model_path)
         checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")

+        pytorch_ckpt_path = f"{checkpoint_path}.pt"
+        export_ckpt_path = f"{checkpoint_path}.onnx"
-        return checkpoint_path
+        return export_ckpt_path, [pytorch_ckpt_path]

     def export(self, output_filepath: str, behavior_name: str) -> None:
         if self.exporter is not None:
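On the Torch side, the .pt file holds the training state while the .onnx file is the exported model. The snippet below is a rough standalone sketch of that pairing; the paths and the Linear module are made up, it assumes PyTorch is installed, and it is not the saver's actual serialization code.

    import os

    import torch  # assumes PyTorch is available

    net = torch.nn.Linear(4, 2)  # hypothetical stand-in for the policy modules

    model_path = "results/run1/Behavior"          # illustrative path
    checkpoint_path = os.path.join(model_path, "Behavior-1000")
    os.makedirs(model_path, exist_ok=True)

    pytorch_ckpt_path = f"{checkpoint_path}.pt"   # training state, used to resume
    export_ckpt_path = f"{checkpoint_path}.onnx"  # exported model, used for inference

    # Only the .pt half is sketched here; the real saver also runs its ONNX exporter.
    torch.save({"policy": net.state_dict()}, pytorch_ckpt_path)

    print(export_ckpt_path, [pytorch_ckpt_path])  # the pair save_checkpoint now returns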

ml-agents/mlagents/trainers/policy/checkpoint_manager.py (15 changes)


     file_path: str
     reward: Optional[float]
     creation_time: float
+    auxillary_file_paths: List[str] = attr.ib(factory=list)


 class ModelCheckpointManager:

         :param checkpoint: A checkpoint stored in checkpoint_list
         """
-        file_path: str = checkpoint["file_path"]
-        if os.path.exists(file_path):
-            os.remove(file_path)
-            logger.debug(f"Removed checkpoint model {file_path}.")
-        else:
-            logger.debug(f"Checkpoint at {file_path} could not be found.")
+        file_paths: List[str] = [checkpoint["file_path"]]
+        file_paths.extend(checkpoint["auxillary_file_paths"])
+        for file_path in file_paths:
+            if os.path.exists(file_path):
+                os.remove(file_path)
+                logger.debug(f"Removed checkpoint model {file_path}.")
+            else:
+                logger.debug(f"Checkpoint at {file_path} could not be found.")
         return

     @classmethod
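A simplified, standalone version of the removal logic above; unlike the real method, this sketch uses .get with an empty-list default so it also tolerates records written before this change.

    import os
    from typing import Any, Dict, List


    def remove_checkpoint_files(checkpoint: Dict[str, Any]) -> None:
        """Simplified stand-in for ModelCheckpointManager.remove_checkpoint."""
        file_paths: List[str] = [checkpoint["file_path"]]
        # The fix: also collect the auxiliary files (e.g. the .pt checkpoint)
        # so they are deleted together with the exported .onnx file.
        file_paths.extend(checkpoint.get("auxillary_file_paths", []))
        for file_path in file_paths:
            if os.path.exists(file_path):
                os.remove(file_path)
            else:
                print(f"Checkpoint at {file_path} could not be found.")


    # Illustrative checkpoint record in the shape the training status file uses.
    remove_checkpoint_files(
        {
            "steps": 1000,
            "file_path": "results/run1/Behavior/Behavior-1000.onnx",
            "auxillary_file_paths": ["results/run1/Behavior/Behavior-1000.pt"],
        }
    )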

ml-agents/mlagents/trainers/tests/test_rl_trainer.py (11 changes)


     def add_policy(self, mock_behavior_id, mock_policy):
         def checkpoint_path(brain_name, step):
-            return os.path.join(self.model_saver.model_path, f"{brain_name}-{step}")
+            onnx_file_path = os.path.join(
+                self.model_saver.model_path, f"{brain_name}-{step}.onnx"
+            )
+            other_file_paths = [
+                os.path.join(self.model_saver.model_path, f"{brain_name}-{step}.pt")
+            ]
+            return onnx_file_path, other_file_paths

         self.policies[mock_behavior_id] = mock_policy
         mock_model_saver = mock.Mock()

                 f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.{export_ext}",
                 None,
                 mock.ANY,
+                [
+                    f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.pt"
+                ],
             ),
             trainer.trainer_settings.keep_checkpoints,
         )
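For intuition, here is a hypothetical pytest-style check of the behavior the updated tests guard: pruning a checkpoint must delete the .pt file as well as the .onnx file. It uses plain dicts and tmp_path rather than the project's mocks.

    import os


    def test_prune_removes_pt_alongside_onnx(tmp_path):
        # Create the two files a single checkpoint produces.
        onnx_path = tmp_path / "Behavior-5.onnx"
        pt_path = tmp_path / "Behavior-5.pt"
        onnx_path.write_bytes(b"")
        pt_path.write_bytes(b"")

        checkpoint = {
            "file_path": str(onnx_path),
            "auxillary_file_paths": [str(pt_path)],
        }
        # Pruning must walk the main path plus every auxiliary path.
        for path in [checkpoint["file_path"], *checkpoint["auxillary_file_paths"]]:
            os.remove(path)

        assert not onnx_path.exists()
        assert not pt_path.exists()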

ml-agents/mlagents/trainers/tests/test_training_status.py (3 changes)


"file_path": os.path.join(final_model_path, f"{brain_name}-1.nn"),
"reward": 1.312,
"creation_time": time.time(),
"auxillary_file_paths": [],
},
{
"steps": 2,

"auxillary_file_paths": [],
},
{
"steps": 3,

"auxillary_file_paths": [],
},
]
GlobalTrainingStatus.set_parameter_state(
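The test data gains an "auxillary_file_paths" entry because the checkpoint record now carries a list field with an empty-list default. The sketch below shows how such an attrs field appears when the record is converted to a dict; the Checkpoint class is a simplified stand-in for ModelCheckpoint, not the real class.

    import time
    from typing import List, Optional

    import attr


    @attr.s(auto_attribs=True)
    class Checkpoint:
        """Simplified stand-in for the ModelCheckpoint record shown above."""
        steps: int
        file_path: str
        reward: Optional[float]
        creation_time: float
        auxillary_file_paths: List[str] = attr.ib(factory=list)


    record = Checkpoint(1, "results/run1/Behavior-1.onnx", 1.312, time.time())
    # attr.asdict fills the new field with an empty list when no aux files exist,
    # which is exactly the shape the status-file test data carries.
    print(attr.asdict(record))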

ml-agents/mlagents/trainers/trainer/rl_trainer.py (8 changes)


             logger.warning(
                 "Trainer has multiple policies, but default behavior only saves the first."
             )
-        checkpoint_path = self.model_saver.save_checkpoint(self.brain_name, self._step)
-        export_ext = "onnx"
+        export_path, auxillary_paths = self.model_saver.save_checkpoint(
+            self.brain_name, self._step
+        )

-            f"{checkpoint_path}.{export_ext}",
+            export_path,
+            auxillary_file_paths=auxillary_paths,
         )
         ModelCheckpointManager.add_checkpoint(
             self.brain_name, new_checkpoint, self.trainer_settings.keep_checkpoints
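Putting the trainer side together: save a checkpoint, record both paths, then prune everything past keep_checkpoints. The helper below is a simplified stand-in for the RLTrainer and ModelCheckpointManager interaction, assuming a saver with the save_checkpoint shape shown above; it is not the project's actual code.

    import os
    from typing import Any, Dict, List


    def checkpoint_and_prune(
        saver, brain_name: str, step: int,
        history: List[Dict[str, Any]], keep_checkpoints: int,
    ) -> None:
        # Saver contract from this PR: (exported .onnx path, [auxiliary paths]).
        export_path, auxillary_paths = saver.save_checkpoint(brain_name, step)
        history.append(
            {"steps": step, "file_path": export_path,
             "auxillary_file_paths": auxillary_paths}
        )
        # Keep only the newest keep_checkpoints records; delete files for the rest.
        while len(history) > keep_checkpoints:
            old = history.pop(0)
            for path in [old["file_path"], *old["auxillary_file_paths"]]:
                if os.path.exists(path):
                    os.remove(path)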

ml-agents/mlagents/trainers/training_status.py (2 changes)


 logger = get_logger(__name__)

-STATUS_FORMAT_VERSION = "0.2.0"
+STATUS_FORMAT_VERSION = "0.3.0"


 class StatusType(Enum):
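Bumping STATUS_FORMAT_VERSION marks that checkpoint entries written by earlier versions may lack the new key. The shim below is only one hypothetical way a loader could normalize such entries; it is not what ml-agents actually does with older status files.

    from typing import Any, Dict, List

    STATUS_FORMAT_VERSION = "0.3.0"


    def normalize_checkpoint_entries(entries: List[Dict[str, Any]], version: str) -> None:
        """Hypothetical compatibility shim for status files written before 0.3.0."""
        if version < STATUS_FORMAT_VERSION:  # naive string compare, adequate for 0.x.y here
            for entry in entries:
                entry.setdefault("auxillary_file_paths", [])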
