浏览代码

Rename NNCheckpoint to ModelCheckpoint as Model can be NN or ONNX (#4540)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
badca342
共有 5 个文件被更改,包括 31 次插入、27 次删除
  1. 10
      ml-agents/mlagents/trainers/policy/checkpoint_manager.py
  2. 4
      ml-agents/mlagents/trainers/sac/trainer.py
  3. 8
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py
  4. 24
      ml-agents/mlagents/trainers/tests/test_training_status.py
  5. 12
      ml-agents/mlagents/trainers/trainer/rl_trainer.py

10
ml-agents/mlagents/trainers/policy/checkpoint_manager.py


@attr.s(auto_attribs=True)
class NNCheckpoint:
class ModelCheckpoint:
steps: int
file_path: str
reward: Optional[float]

class NNCheckpointManager:
class ModelCheckpointManager:
@staticmethod
def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
checkpoint_list = GlobalTrainingStatus.get_parameter_state(

while len(checkpoints) > keep_checkpoints:
if keep_checkpoints <= 0 or len(checkpoints) == 0:
break
NNCheckpointManager.remove_checkpoint(checkpoints.pop(0))
ModelCheckpointManager.remove_checkpoint(checkpoints.pop(0))
cls, behavior_name: str, new_checkpoint: NNCheckpoint, keep_checkpoints: int
cls, behavior_name: str, new_checkpoint: ModelCheckpoint, keep_checkpoints: int
) -> None:
"""
Make room for new checkpoint if needed and insert new checkpoint information.

@classmethod
def track_final_checkpoint(
cls, behavior_name: str, final_checkpoint: NNCheckpoint
cls, behavior_name: str, final_checkpoint: ModelCheckpoint
) -> None:
"""
Ensures number of checkpoints stored is within the max number of checkpoints

4
ml-agents/mlagents/trainers/sac/trainer.py


import os
import numpy as np
from mlagents.trainers.policy.checkpoint_manager import NNCheckpoint
from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint
from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import timed

self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer
def _checkpoint(self) -> NNCheckpoint:
def _checkpoint(self) -> ModelCheckpoint:
"""
Writes a checkpoint model to memory
Overrides the default to save the replay buffer.

8
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


from unittest import mock
import pytest
import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.policy.checkpoint_manager import NNCheckpoint
from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.tests.test_buffer import construct_fake_buffer
from mlagents.trainers.agent_processor import AgentManagerQueue

"framework", [FrameworkType.TENSORFLOW, FrameworkType.PYTORCH], ids=["tf", "torch"]
)
@mock.patch("mlagents.trainers.trainer.trainer.StatsReporter.write_stats")
@mock.patch("mlagents.trainers.trainer.rl_trainer.NNCheckpointManager.add_checkpoint")
@mock.patch(
"mlagents.trainers.trainer.rl_trainer.ModelCheckpointManager.add_checkpoint"
)
def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary, framework):
trainer = create_rl_trainer(framework)
mock_policy = mock.Mock()

add_checkpoint_calls = [
mock.call(
trainer.brain_name,
NNCheckpoint(
ModelCheckpoint(
step,
f"{trainer.model_saver.model_path}/{trainer.brain_name}-{step}.{export_ext}",
None,

24
ml-agents/mlagents/trainers/tests/test_training_status.py


GlobalTrainingStatus,
)
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpointManager,
NNCheckpoint,
ModelCheckpointManager,
ModelCheckpoint,
)

brain_name, StatusType.CHECKPOINTS, test_checkpoint_list
)
new_checkpoint_4 = NNCheckpoint(
new_checkpoint_4 = ModelCheckpoint(
NNCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
new_checkpoint_5 = NNCheckpoint(
new_checkpoint_5 = ModelCheckpoint(
NNCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
final_model = NNCheckpoint(current_step, final_model_path, 3.294, final_model_time)
final_model = ModelCheckpoint(
current_step, final_model_path, 3.294, final_model_time
)
NNCheckpointManager.track_final_checkpoint(brain_name, final_model)
assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
ModelCheckpointManager.track_final_checkpoint(brain_name, final_model)
assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
check_checkpoints = GlobalTrainingStatus.saved_state[brain_name][
StatusType.CHECKPOINTS.value

12
ml-agents/mlagents/trainers/trainer/rl_trainer.py


import time
import attr
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpoint,
NNCheckpointManager,
ModelCheckpoint,
ModelCheckpointManager,
)
from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import timed

return sum(rewards) / len(rewards)
@timed
def _checkpoint(self) -> NNCheckpoint:
def _checkpoint(self) -> ModelCheckpoint:
"""
Checkpoints the policy associated with this trainer.
"""

)
checkpoint_path = self.model_saver.save_checkpoint(self.brain_name, self.step)
export_ext = "nn" if self.framework == FrameworkType.TENSORFLOW else "onnx"
new_checkpoint = NNCheckpoint(
new_checkpoint = ModelCheckpoint(
NNCheckpointManager.add_checkpoint(
ModelCheckpointManager.add_checkpoint(
self.brain_name, new_checkpoint, self.trainer_settings.keep_checkpoints
)
return new_checkpoint

final_checkpoint = attr.evolve(
model_checkpoint, file_path=f"{self.model_saver.model_path}.{export_ext}"
)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
ModelCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
@abc.abstractmethod
def _update_policy(self) -> bool:

正在加载...
取消
保存