浏览代码

Rename saver

/MLA-1734-demo-provider
Ruo-Ping Dong 5 年前
当前提交
c47ffc20
共有 15 个文件被更改,包括 71 次插入70 次删除
  1. 2
      ml-agents/mlagents/trainers/ghost/trainer.py
  2. 6
      ml-agents/mlagents/trainers/ppo/trainer.py
  3. 6
      ml-agents/mlagents/trainers/sac/trainer.py
  4. 8
      ml-agents/mlagents/trainers/tests/test_ppo.py
  5. 14
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py
  6. 7
      ml-agents/mlagents/trainers/tests/test_sac.py
  7. 48
      ml-agents/mlagents/trainers/tests/test_saver.py
  8. 24
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  9. 12
      ml-agents/mlagents/trainers/model_saver/model_saver.py
  10. 6
      ml-agents/mlagents/trainers/model_saver/tf_model_saver.py
  11. 8
      ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
  12. 0
      /ml-agents/mlagents/trainers/model_saver
  13. 0
      /ml-agents/mlagents/trainers/model_saver/model_saver.py
  14. 0
      /ml-agents/mlagents/trainers/model_saver/tf_model_saver.py
  15. 0
      /ml-agents/mlagents/trainers/model_saver/torch_model_saver.py

2
ml-agents/mlagents/trainers/ghost/trainer.py


policy = self.trainer.create_policy(
parsed_behavior_id, behavior_spec, create_graph=True
)
self.trainer.saver.initialize_or_load(policy)
self.trainer.model_saver.initialize_or_load(policy)
team_id = parsed_behavior_id.team_id
self.controller.subscribe_team_id(team_id, self)

6
ml-agents/mlagents/trainers/ppo/trainer.py


for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
self.saver.register(self.policy)
self.saver.register(self.optimizer)
self.saver.initialize_or_load()
self.model_saver.register(self.policy)
self.model_saver.register(self.optimizer)
self.model_saver.initialize_or_load()
# Needed to resume loads properly
self.step = policy.get_current_step()

6
ml-agents/mlagents/trainers/sac/trainer.py


for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
self.saver.register(self.policy)
self.saver.register(self.optimizer)
self.saver.initialize_or_load()
self.model_saver.register(self.policy)
self.model_saver.register(self.optimizer)
self.model_saver.initialize_or_load()
# Needed to resume loads properly
self.step = policy.get_current_step()

8
ml-agents/mlagents/trainers/tests/test_ppo.py


)
@mock.patch.object(RLTrainer, "create_saver")
@mock.patch.object(RLTrainer, "create_model_saver")
def test_trainer_increment_step(ppo_optimizer, mock_create_saver):
def test_trainer_increment_step(ppo_optimizer, mock_create_model_saver):
trainer_params = PPO_CONFIG
mock_optimizer = mock.Mock()
mock_optimizer.reward_signals = {}

assert trainer.stats_reporter.get_stats_summaries("Policy/Extrinsic Reward").num > 0
@mock.patch.object(RLTrainer, "create_saver")
@mock.patch.object(RLTrainer, "create_model_saver")
def test_add_get_policy(ppo_optimizer, mock_create_saver, dummy_config):
def test_add_get_policy(ppo_optimizer, mock_create_model_saver, dummy_config):
mock_optimizer = mock.Mock()
mock_optimizer.reward_signals = {}
ppo_optimizer.return_value = mock_optimizer

14
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


def add_policy(self, mock_behavior_id, mock_policy):
def checkpoint_path(brain_name, step):
return os.path.join(self.saver.model_path, f"{brain_name}-{step}")
return os.path.join(self.model_saver.model_path, f"{brain_name}-{step}")
mock_saver = mock.Mock()
mock_saver.model_path = self.artifact_path
mock_saver.save_checkpoint.side_effect = checkpoint_path
self.saver = mock_saver
mock_model_saver = mock.Mock()
mock_model_saver.model_path = self.artifact_path
mock_model_saver.save_checkpoint.side_effect = checkpoint_path
self.model_saver = mock_model_saver
def create_tf_policy(self, parsed_behavior_id, behavior_spec):
return mock.Mock()

checkpoint_interval, num_trajectories * time_horizon, checkpoint_interval
)
calls = [mock.call(trainer.brain_name, step) for step in checkpoint_range]
trainer.saver.save_checkpoint.assert_has_calls(calls, any_order=True)
trainer.model_saver.save_checkpoint.assert_has_calls(calls, any_order=True)
add_checkpoint_calls = [
mock.call(

f"{trainer.saver.model_path}/{trainer.brain_name}-{step}.nn",
f"{trainer.model_saver.model_path}/{trainer.brain_name}-{step}.nn",
None,
mock.ANY,
),

7
ml-agents/mlagents/trainers/tests/test_sac.py


assert trainer2.update_buffer.num_experiences == buffer_len
@mock.patch.object(RLTrainer, "create_saver")
@mock.patch.object(RLTrainer, "create_model_saver")
def test_add_get_policy(sac_optimizer, mock_create_saver, dummy_config):
def test_add_get_policy(sac_optimizer, mock_create_model_saver, dummy_config):
mock_optimizer = mock.Mock()
mock_optimizer.reward_signals = {}
sac_optimizer.return_value = mock_optimizer

policy = trainer.create_policy(behavior_id, specs)
policy.get_current_step = lambda: 200
trainer.add_policy(behavior_id, policy)
trainer.saver.initialize_or_load(policy)
trainer.saver.initialize_or_load(policy)
trainer.model_saver.initialize_or_load(policy)
trainer.optimizer.update_reward_signals = mock.Mock()
trainer.optimizer.update_reward_signals.return_value = {}
trainer.optimizer.update.return_value = {}

48
ml-agents/mlagents/trainers/tests/test_saver.py


import numpy as np
from mlagents.tf_utils import tf
from mlagents.trainers.saver.tf_saver import TFSaver
from mlagents.trainers.model_saver.tf_model_saver import TFModelSaver
from mlagents.trainers import __version__
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.policy.tf_policy import TFPolicy

def test_register(tmp_path):
trainer_params = TrainerSettings()
saver = TFSaver(trainer_params, tmp_path)
model_saver = TFModelSaver(trainer_params, tmp_path)
saver.register(opt)
assert saver.policy is None
model_saver.register(opt)
assert model_saver.policy is None
saver.register(policy)
assert saver.policy is not None
model_saver.register(policy)
assert model_saver.policy is not None
class ModelVersionTest(unittest.TestCase):

trainer_params = TrainerSettings()
mock_path = tempfile.mkdtemp()
policy = create_policy_mock(trainer_params)
saver = TFSaver(trainer_params, mock_path)
saver.register(policy)
model_saver = TFModelSaver(trainer_params, mock_path)
model_saver.register(policy)
saver._check_model_version(
model_saver._check_model_version(
saver._check_model_version(__version__) # This should be the right version
model_saver._check_model_version(
__version__
) # This should be the right version
# Assert that no additional warnings have been thrown wth correct ver
assert len(cm.output) == 1

path2 = os.path.join(tmp_path, "runid2")
trainer_params = TrainerSettings()
policy = create_policy_mock(trainer_params)
saver = TFSaver(trainer_params, path1)
saver.register(policy)
saver.initialize_or_load(policy)
model_saver = TFModelSaver(trainer_params, path1)
model_saver.register(policy)
model_saver.initialize_or_load(policy)
saver.save_checkpoint(mock_brain_name, 2000)
model_saver.save_checkpoint(mock_brain_name, 2000)
saver = TFSaver(trainer_params, path1, load=True)
model_saver = TFModelSaver(trainer_params, path1, load=True)
saver.register(policy2)
saver.initialize_or_load(policy2)
model_saver.register(policy2)
model_saver.initialize_or_load(policy2)
saver = TFSaver(trainer_params, path2)
model_saver = TFModelSaver(trainer_params, path2)
saver.register(policy3)
saver.initialize_or_load(policy3)
model_saver.register(policy3)
model_saver.initialize_or_load(policy3)
_compare_two_policies(policy2, policy3)
# Assert that the steps are 0.

dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
trainer_params = TrainerSettings()
saver = TFSaver(trainer_params, model_path)
saver.register(policy)
saver.save_checkpoint("Mock_Brain", 100)
model_saver = TFModelSaver(trainer_params, model_path)
model_saver.register(policy)
model_saver.save_checkpoint("Mock_Brain", 100)
assert os.path.isfile(model_path + "/Mock_Brain-100.nn")

24
ml-agents/mlagents/trainers/trainer/rl_trainer.py


from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.settings import TrainerSettings, FrameworkType
from mlagents.trainers.stats import StatsPropertyType
from mlagents.trainers.saver.saver import BaseSaver
from mlagents.trainers.saver.tf_saver import TFSaver
from mlagents.trainers.model_saver.model_saver import BaseModelSaver
from mlagents.trainers.model_saver.tf_model_saver import TFModelSaver
from mlagents.trainers.saver.torch_saver import TorchSaver
from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
except ModuleNotFoundError:
TorchPolicy = None # type: ignore

self._next_save_step = 0
self._next_summary_step = 0
self.saver = self.create_saver(
self.model_saver = self.create_model_saver(
self.framework, self.trainer_settings, self.artifact_path, self.load
)

pass
@staticmethod
def create_saver(
def create_model_saver(
) -> BaseSaver:
) -> BaseModelSaver:
saver = TorchSaver( # type: ignore
model_saver = TorchModelSaver( # type: ignore
saver = TFSaver( # type: ignore
model_saver = TFModelSaver( # type: ignore
return saver
return model_saver
def _policy_mean_reward(self) -> Optional[float]:
""" Returns the mean episode reward for the current policy. """

logger.warning(
"Trainer has multiple policies, but default behavior only saves the first."
)
checkpoint_path = self.saver.save_checkpoint(self.brain_name, self.step)
checkpoint_path = self.model_saver.save_checkpoint(self.brain_name, self.step)
new_checkpoint = NNCheckpoint(
int(self.step),
f"{checkpoint_path}.nn",

return
model_checkpoint = self._checkpoint()
self.saver.copy_final_model(model_checkpoint.file_path)
self.model_saver.copy_final_model(model_checkpoint.file_path)
model_checkpoint, file_path=f"{self.saver.model_path}.nn"
model_checkpoint, file_path=f"{self.model_saver.model_path}.nn"
)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)

12
ml-agents/mlagents/trainers/model_saver/model_saver.py


from typing import Any
class BaseSaver(abc.ABC):
"""This class is the base class for the Saver"""
class BaseModelSaver(abc.ABC):
"""This class is the base class for the ModelSaver"""
def __init__(self):
pass

"""
Register the modules to the Saver.
The Saver will store the module and include it in the saved files
Register the modules to the ModelSaver.
The ModelSaver will store the module and include it in the saved files
when saving checkpoint/exporting graph.
:param module: the module to be registered
"""

"""
Helper function for registering policy to the Saver.
Helper function for registering policy to the ModelSaver.
:param policy: the policy to be registered
"""
pass

Helper function for registering optimizer to the Saver.
Helper function for registering optimizer to the ModelSaver.
:param optimizer: the optimizer to be registered
"""
pass

6
ml-agents/mlagents/trainers/model_saver/tf_model_saver.py


from mlagents_envs.exception import UnityPolicyException
from mlagents_envs.logging_util import get_logger
from mlagents.tf_utils import tf
from mlagents.trainers.saver.saver import BaseSaver
from mlagents.trainers.model_saver.model_saver import BaseModelSaver
from mlagents.trainers.tf.model_serialization import export_policy_model
from mlagents.trainers.settings import TrainerSettings, SerializationSettings
from mlagents.trainers.policy.tf_policy import TFPolicy

logger = get_logger(__name__)
class TFSaver(BaseSaver):
class TFModelSaver(BaseModelSaver):
Saver class for TensorFlow
ModelSaver class for TensorFlow
"""
def __init__(

8
ml-agents/mlagents/trainers/model_saver/torch_model_saver.py


from typing import Dict, Union, Optional, cast
from mlagents_envs.exception import UnityPolicyException
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.saver.saver import BaseSaver
from mlagents.trainers.model_saver.model_saver import BaseModelSaver
from mlagents.trainers.settings import TrainerSettings, SerializationSettings
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer

logger = get_logger(__name__)
class TorchSaver(BaseSaver):
class TorchModelSaver(BaseModelSaver):
Saver class for PyTorch
ModelSaver class for PyTorch
"""
def __init__(

self.modules.update(module.get_modules()) # type: ignore
else:
raise UnityPolicyException(
"Registering Object of unsupported type {} to Saver ".format(
"Registering Object of unsupported type {} to ModelSaver ".format(
type(module)
)
)

/ml-agents/mlagents/trainers/saver → /ml-agents/mlagents/trainers/model_saver

/ml-agents/mlagents/trainers/model_saver/saver.py → /ml-agents/mlagents/trainers/model_saver/model_saver.py

/ml-agents/mlagents/trainers/model_saver/tf_saver.py → /ml-agents/mlagents/trainers/model_saver/tf_model_saver.py

/ml-agents/mlagents/trainers/model_saver/torch_saver.py → /ml-agents/mlagents/trainers/model_saver/torch_model_saver.py

正在加载...
取消
保存