浏览代码
Convert checkpoints to .NN (#4127)
Convert checkpoints to .NN (#4127)
This change adds an export to .nn for each checkpoint generated by RLTrainer and adds a NNCheckpointManager to track the generated checkpoints and final model in training_status.json. Co-authored-by: Jonathan Harper <jharper+moar@unity3d.com>/MLA-1734-demo-provider
GitHub
5 年前
当前提交
84440f05
共有 22 个文件被更改,包括 476 次插入 和 182 次删除
-
4docs/Training-ML-Agents.md
-
20ml-agents/mlagents/model_serialization.py
-
17ml-agents/mlagents/trainers/ghost/trainer.py
-
1ml-agents/mlagents/trainers/learn.py
-
31ml-agents/mlagents/trainers/policy/tf_policy.py
-
1ml-agents/mlagents/trainers/ppo/trainer.py
-
22ml-agents/mlagents/trainers/sac/trainer.py
-
12ml-agents/mlagents/trainers/tests/test_barracuda_converter.py
-
4ml-agents/mlagents/trainers/tests/test_config_conversion.py
-
7ml-agents/mlagents/trainers/tests/test_nn_policy.py
-
27ml-agents/mlagents/trainers/tests/test_ppo.py
-
43ml-agents/mlagents/trainers/tests/test_rl_trainer.py
-
31ml-agents/mlagents/trainers/tests/test_sac.py
-
8ml-agents/mlagents/trainers/tests/test_trainer_controller.py
-
64ml-agents/mlagents/trainers/tests/test_training_status.py
-
66ml-agents/mlagents/trainers/trainer/rl_trainer.py
-
22ml-agents/mlagents/trainers/trainer/trainer.py
-
15ml-agents/mlagents/trainers/trainer_controller.py
-
2ml-agents/mlagents/trainers/training_status.py
-
98ml-agents/mlagents/trainers/policy/checkpoint_manager.py
-
92ml-agents/mlagents/trainers/tests/test_tf_policy.py
-
71ml-agents/mlagents/trainers/tests/test_policy.py
|
|||
# # Unity ML-Agents Toolkit |
|||
from typing import Dict, Any, Optional, List |
|||
import os |
|||
import attr |
|||
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType |
|||
from mlagents_envs.logging_util import get_logger |
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
|
|||
@attr.s(auto_attribs=True) |
|||
class NNCheckpoint: |
|||
steps: int |
|||
file_path: str |
|||
reward: Optional[float] |
|||
creation_time: float |
|||
|
|||
|
|||
class NNCheckpointManager: |
|||
@staticmethod |
|||
def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]: |
|||
checkpoint_list = GlobalTrainingStatus.get_parameter_state( |
|||
behavior_name, StatusType.CHECKPOINTS |
|||
) |
|||
if not checkpoint_list: |
|||
checkpoint_list = [] |
|||
GlobalTrainingStatus.set_parameter_state( |
|||
behavior_name, StatusType.CHECKPOINTS, checkpoint_list |
|||
) |
|||
return checkpoint_list |
|||
|
|||
@staticmethod |
|||
def remove_checkpoint(checkpoint: Dict[str, Any]) -> None: |
|||
""" |
|||
Removes a checkpoint stored in checkpoint_list. |
|||
If checkpoint cannot be found, no action is done. |
|||
|
|||
:param checkpoint: A checkpoint stored in checkpoint_list |
|||
""" |
|||
file_path: str = checkpoint["file_path"] |
|||
if os.path.exists(file_path): |
|||
os.remove(file_path) |
|||
logger.info(f"Removed checkpoint model {file_path}.") |
|||
else: |
|||
logger.info(f"Checkpoint at {file_path} could not be found.") |
|||
return |
|||
|
|||
@classmethod |
|||
def _cleanup_extra_checkpoints( |
|||
cls, checkpoints: List[Dict], keep_checkpoints: int |
|||
) -> List[Dict]: |
|||
""" |
|||
Ensures that the number of checkpoints stored are within the number |
|||
of checkpoints the user defines. If the limit is hit, checkpoints are |
|||
removed to create room for the next checkpoint to be inserted. |
|||
|
|||
:param behavior_name: The behavior name whose checkpoints we will mange. |
|||
:param keep_checkpoints: Number of checkpoints to record (user-defined). |
|||
""" |
|||
while len(checkpoints) > keep_checkpoints: |
|||
if keep_checkpoints <= 0 or len(checkpoints) == 0: |
|||
break |
|||
NNCheckpointManager.remove_checkpoint(checkpoints.pop(0)) |
|||
return checkpoints |
|||
|
|||
@classmethod |
|||
def add_checkpoint( |
|||
cls, behavior_name: str, new_checkpoint: NNCheckpoint, keep_checkpoints: int |
|||
) -> None: |
|||
""" |
|||
Make room for new checkpoint if needed and insert new checkpoint information. |
|||
:param behavior_name: Behavior name for the checkpoint. |
|||
:param new_checkpoint: The new checkpoint to be recorded. |
|||
:param keep_checkpoints: Number of checkpoints to record (user-defined). |
|||
""" |
|||
new_checkpoint_dict = attr.asdict(new_checkpoint) |
|||
checkpoints = cls.get_checkpoints(behavior_name) |
|||
checkpoints.append(new_checkpoint_dict) |
|||
cls._cleanup_extra_checkpoints(checkpoints, keep_checkpoints) |
|||
GlobalTrainingStatus.set_parameter_state( |
|||
behavior_name, StatusType.CHECKPOINTS, checkpoints |
|||
) |
|||
|
|||
@classmethod |
|||
def track_final_checkpoint( |
|||
cls, behavior_name: str, final_checkpoint: NNCheckpoint |
|||
) -> None: |
|||
""" |
|||
Ensures number of checkpoints stored is within the max number of checkpoints |
|||
defined by the user and finally stores the information about the final |
|||
model (or intermediate model if training is interrupted). |
|||
:param behavior_name: Behavior name of the model. |
|||
:param final_checkpoint: Checkpoint information for the final model. |
|||
""" |
|||
final_model_dict = attr.asdict(final_checkpoint) |
|||
GlobalTrainingStatus.set_parameter_state( |
|||
behavior_name, StatusType.FINAL_CHECKPOINT, final_model_dict |
|||
) |
|
|||
from mlagents.model_serialization import SerializationSettings |
|||
from mlagents.trainers.policy.tf_policy import TFPolicy |
|||
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec |
|||
from mlagents.trainers.action_info import ActionInfo |
|||
from unittest.mock import MagicMock |
|||
from unittest import mock |
|||
from mlagents.trainers.settings import TrainerSettings |
|||
import numpy as np |
|||
|
|||
|
|||
def basic_mock_brain(): |
|||
mock_brain = MagicMock() |
|||
mock_brain.vector_action_space_type = "continuous" |
|||
mock_brain.vector_observation_space_size = 1 |
|||
mock_brain.vector_action_space_size = [1] |
|||
mock_brain.brain_name = "MockBrain" |
|||
return mock_brain |
|||
|
|||
|
|||
class FakePolicy(TFPolicy): |
|||
def create_tf_graph(self): |
|||
pass |
|||
|
|||
def get_trainable_variables(self): |
|||
return [] |
|||
|
|||
|
|||
def test_take_action_returns_empty_with_no_agents(): |
|||
test_seed = 3 |
|||
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output") |
|||
# Doesn't really matter what this is |
|||
dummy_groupspec = BehaviorSpec([(1,)], "continuous", 1) |
|||
no_agent_step = DecisionSteps.empty(dummy_groupspec) |
|||
result = policy.get_action(no_agent_step) |
|||
assert result == ActionInfo.empty() |
|||
|
|||
|
|||
def test_take_action_returns_nones_on_missing_values(): |
|||
test_seed = 3 |
|||
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output") |
|||
policy.evaluate = MagicMock(return_value={}) |
|||
policy.save_memories = MagicMock() |
|||
step_with_agents = DecisionSteps( |
|||
[], np.array([], dtype=np.float32), np.array([0]), None |
|||
) |
|||
result = policy.get_action(step_with_agents, worker_id=0) |
|||
assert result == ActionInfo(None, None, {}, [0]) |
|||
|
|||
|
|||
def test_take_action_returns_action_info_when_available(): |
|||
test_seed = 3 |
|||
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output") |
|||
policy_eval_out = { |
|||
"action": np.array([1.0], dtype=np.float32), |
|||
"memory_out": np.array([[2.5]], dtype=np.float32), |
|||
"value": np.array([1.1], dtype=np.float32), |
|||
} |
|||
policy.evaluate = MagicMock(return_value=policy_eval_out) |
|||
step_with_agents = DecisionSteps( |
|||
[], np.array([], dtype=np.float32), np.array([0]), None |
|||
) |
|||
result = policy.get_action(step_with_agents) |
|||
expected = ActionInfo( |
|||
policy_eval_out["action"], policy_eval_out["value"], policy_eval_out, [0] |
|||
) |
|||
assert result == expected |
|||
|
|||
|
|||
def test_convert_version_string(): |
|||
result = TFPolicy._convert_version_string("200.300.100") |
|||
assert result == (200, 300, 100) |
|||
# Test dev versions |
|||
result = TFPolicy._convert_version_string("200.300.100.dev0") |
|||
assert result == (200, 300, 100) |
|||
|
|||
|
|||
@mock.patch("mlagents.trainers.policy.tf_policy.export_policy_model") |
|||
@mock.patch("time.time", mock.MagicMock(return_value=12345)) |
|||
def test_checkpoint_writes_tf_and_nn_checkpoints(export_policy_model_mock): |
|||
mock_brain = basic_mock_brain() |
|||
test_seed = 4 # moving up in the world |
|||
policy = FakePolicy(test_seed, mock_brain, TrainerSettings(), "output") |
|||
n_steps = 5 |
|||
policy.get_current_step = MagicMock(return_value=n_steps) |
|||
policy.saver = MagicMock() |
|||
serialization_settings = SerializationSettings("output", mock_brain.brain_name) |
|||
checkpoint_path = f"output/{mock_brain.brain_name}-{n_steps}" |
|||
policy.checkpoint(checkpoint_path, serialization_settings) |
|||
policy.saver.save.assert_called_once_with(policy.sess, f"{checkpoint_path}.ckpt") |
|||
export_policy_model_mock.assert_called_once_with( |
|||
checkpoint_path, serialization_settings, policy.graph, policy.sess |
|||
) |
|
|||
from mlagents.trainers.policy.tf_policy import TFPolicy |
|||
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec |
|||
from mlagents.trainers.action_info import ActionInfo |
|||
from unittest.mock import MagicMock |
|||
from mlagents.trainers.settings import TrainerSettings |
|||
import numpy as np |
|||
|
|||
|
|||
def basic_mock_brain(): |
|||
mock_brain = MagicMock() |
|||
mock_brain.vector_action_space_type = "continuous" |
|||
mock_brain.vector_observation_space_size = 1 |
|||
mock_brain.vector_action_space_size = [1] |
|||
return mock_brain |
|||
|
|||
|
|||
class FakePolicy(TFPolicy): |
|||
def create_tf_graph(self): |
|||
pass |
|||
|
|||
def get_trainable_variables(self): |
|||
return [] |
|||
|
|||
|
|||
def test_take_action_returns_empty_with_no_agents(): |
|||
test_seed = 3 |
|||
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output") |
|||
# Doesn't really matter what this is |
|||
dummy_groupspec = BehaviorSpec([(1,)], "continuous", 1) |
|||
no_agent_step = DecisionSteps.empty(dummy_groupspec) |
|||
result = policy.get_action(no_agent_step) |
|||
assert result == ActionInfo.empty() |
|||
|
|||
|
|||
def test_take_action_returns_nones_on_missing_values(): |
|||
test_seed = 3 |
|||
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output") |
|||
policy.evaluate = MagicMock(return_value={}) |
|||
policy.save_memories = MagicMock() |
|||
step_with_agents = DecisionSteps( |
|||
[], np.array([], dtype=np.float32), np.array([0]), None |
|||
) |
|||
result = policy.get_action(step_with_agents, worker_id=0) |
|||
assert result == ActionInfo(None, None, {}, [0]) |
|||
|
|||
|
|||
def test_take_action_returns_action_info_when_available(): |
|||
test_seed = 3 |
|||
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output") |
|||
policy_eval_out = { |
|||
"action": np.array([1.0], dtype=np.float32), |
|||
"memory_out": np.array([[2.5]], dtype=np.float32), |
|||
"value": np.array([1.1], dtype=np.float32), |
|||
} |
|||
policy.evaluate = MagicMock(return_value=policy_eval_out) |
|||
step_with_agents = DecisionSteps( |
|||
[], np.array([], dtype=np.float32), np.array([0]), None |
|||
) |
|||
result = policy.get_action(step_with_agents) |
|||
expected = ActionInfo( |
|||
policy_eval_out["action"], policy_eval_out["value"], policy_eval_out, [0] |
|||
) |
|||
assert result == expected |
|||
|
|||
|
|||
def test_convert_version_string(): |
|||
result = TFPolicy._convert_version_string("200.300.100") |
|||
assert result == (200, 300, 100) |
|||
# Test dev versions |
|||
result = TFPolicy._convert_version_string("200.300.100.dev0") |
|||
assert result == (200, 300, 100) |
撰写
预览
正在加载...
取消
保存
Reference in new issue