浏览代码

update saver interface and add tests

/develop/add-fire/ckpt-2
Ruo-Ping Dong 5 年前
当前提交
95858e25
共有 18 个文件被更改,包括 326 次插入172 次删除
  1. 4
      ml-agents/mlagents/trainers/ghost/trainer.py
  2. 4
      ml-agents/mlagents/trainers/policy/tf_policy.py
  3. 13
      ml-agents/mlagents/trainers/ppo/trainer.py
  4. 14
      ml-agents/mlagents/trainers/sac/trainer.py
  5. 28
      ml-agents/mlagents/trainers/saver/saver.py
  6. 138
      ml-agents/mlagents/trainers/saver/tf_saver.py
  7. 99
      ml-agents/mlagents/trainers/saver/torch_saver.py
  8. 10
      ml-agents/mlagents/trainers/tests/test_bcmodule.py
  9. 9
      ml-agents/mlagents/trainers/tests/test_ppo.py
  10. 1
      ml-agents/mlagents/trainers/tests/test_reward_signals.py
  11. 12
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py
  12. 4
      ml-agents/mlagents/trainers/tests/test_sac.py
  13. 4
      ml-agents/mlagents/trainers/tests/test_simple_rl.py
  14. 9
      ml-agents/mlagents/trainers/tf/model_serialization.py
  15. 2
      ml-agents/mlagents/trainers/torch/model_serialization.py
  16. 32
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  17. 2
      ml-agents/mlagents/trainers/trainer/trainer.py
  18. 113
      ml-agents/mlagents/trainers/tests/test_saver.py

4
ml-agents/mlagents/trainers/ghost/trainer.py


"""
policy = self.trainer.create_policy(parsed_behavior_id, behavior_spec)
policy.create_tf_graph()
policy.initialize_or_load()
self.trainer.saver.initialize_or_load(policy)
policy.init_load_weights()
team_id = parsed_behavior_id.team_id
self.controller.subscribe_team_id(team_id, self)

self,
parsed_behavior_id: BehaviorIdentifiers,
policy: Policy,
create_saver: bool = True,
register_saver: bool = True,
) -> None:
"""
Adds policy to GhostTrainer.

4
ml-agents/mlagents/trainers/policy/tf_policy.py


# We do an initialize to make the Policy usable out of the box. If an optimizer is needed,
# it will re-load the full graph
self._initialize_graph()
self.initialize()
def _create_encoder(
self,

ver = LooseVersion(version_string)
return tuple(map(int, ver.version[0:3]))
def _initialize_graph(self):
def initialize(self):
with self.graph.as_default():
init = tf.global_variables_initializer()
self.sess.run(init)

13
ml-agents/mlagents/trainers/ppo/trainer.py


self,
parsed_behavior_id: BehaviorIdentifiers,
policy: Policy,
create_saver: bool = True,
register_saver: bool = True,
) -> None:
"""
Adds policy to trainer.

for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
if self.saver is None and create_saver:
self.saver = self.create_saver(
self.framework,
policy,
self.trainer_settings,
self.artifact_path,
self.load,
)
if register_saver:
self.saver.maybe_load()
self.saver.initialize_or_load(self.policy)
# Needed to resume loads properly
self.step = policy.get_current_step()

14
ml-agents/mlagents/trainers/sac/trainer.py


brain_name, trainer_settings, training, load, artifact_path, reward_buff_cap
)
self.load = load
self.seed = seed
self.policy: Policy = None # type: ignore
self.optimizer: SACOptimizer = None # type: ignore

self,
parsed_behavior_id: BehaviorIdentifiers,
policy: Policy,
create_saver: bool = True,
register_saver: bool = True,
) -> None:
"""
Adds policy to trainer.

for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
if self.saver is None and create_saver:
self.saver = self.create_saver(
self.framework,
policy,
self.trainer_settings,
self.artifact_path,
self.load,
)
if register_saver:
self.saver.maybe_load()
self.saver.initialize_or_load(self.policy)
# Needed to resume loads properly
self.step = policy.get_current_step()

28
ml-agents/mlagents/trainers/saver/saver.py


# # Unity ML-Agents Toolkit
import abc
from typing import Any
class BaseSaver(abc.ABC):

"""
TBA
"""
def register(self, module):
def register(self, module: Any) -> None:
"""
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param brain_name: Brain name of brain to be trained
"""
def maybe_load(self):
def export(self, output_filepath: str, brain_name: str) -> None:
"""
Saves the serialized model, given a path and brain name.
This method will save the policy graph to the given filepath. The path
should be provided without an extension as multiple serialized model formats
may be generated as a result.
:param output_filepath: path (without suffix) for the model file(s)
:param brain_name: Brain name of brain to be trained.
"""
def export(self, output_filepath: str, brain_name: str) -> None:
def initialize_or_load(self, policy):
"""
If there is an initialize path, load from that. Else, load from the set model path.
If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
e.g., resume from an initialize path.
"""
pass

138
ml-agents/mlagents/trainers/saver/tf_saver.py


import os
import shutil
from typing import Optional, Union, cast
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.settings import TrainerSettings, SerializationSettings
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers import __version__

"""
def __init__(
self,
policy: TFPolicy,
trainer_settings: TrainerSettings,
model_path: str,
load: bool = False,
self, trainer_settings: TrainerSettings, model_path: str, load: bool = False
self.policy = policy
self.graph = self.policy.graph
self.sess = self.policy.sess
with self.graph.as_default():
self.saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
# Currently only support saving one policy. This is the one to be saved.
self.policy: Optional[TFPolicy] = None
self.graph = None
self.sess = None
self.tf_saver = None
def register(self, module):
pass
def register(self, module: Union[TFPolicy, TFOptimizer]) -> None:
if isinstance(module, TFPolicy):
if self.policy is None:
self.policy = module
self.graph = self.policy.graph
self.sess = self.policy.sess
with self.policy.graph.as_default():
self.tf_saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)
"""
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param brain_name: Brain name of brain to be trained
"""
with self.graph.as_default():
if self.saver:
self.saver.save(self.sess, f"{checkpoint_path}.ckpt")
tf.train.write_graph(
self.graph, self.model_path, "raw_graph_def.pb", as_text=False
)
if self.graph:
with self.graph.as_default():
if self.tf_saver:
self.tf_saver.save(self.sess, f"{checkpoint_path}.ckpt")
tf.train.write_graph(
self.graph, self.model_path, "raw_graph_def.pb", as_text=False
)
"""
Saves the serialized model, given a path and brain name.
export_policy_model(
self.model_path, output_filepath, brain_name, self.graph, self.sess
)
This method will save the policy graph to the given filepath. The path
should be provided without an extension as multiple serialized model formats
may be generated as a result.
:param output_filepath: path (without suffix) for the model file(s)
:param brain_name: Brain name of brain to be trained.
"""
export_policy_model(output_filepath, brain_name, self.graph, self.sess)
def maybe_load(self):
# If there is an initialize path, load from that. Else, load from the set model path.
# If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
# e.g., resume from an initialize path.
def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
# Initialize/Load registered self.policy by default.
# If given input argument policy, use the input policy instead.
# This argument is mainly for initialization of the ghost trainer's fixed policy.
if policy is None:
policy = self.policy
policy = cast(TFPolicy, policy)
self._load_graph(self.initialize_path, reset_global_steps=reset_steps)
self._load_graph(
policy, self.initialize_path, reset_global_steps=reset_steps
)
self._load_graph(self.model_path, reset_global_steps=reset_steps)
self._load_graph(policy, self.model_path, reset_global_steps=reset_steps)
self.policy._initialize_graph()
policy.initialize()
def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
with self.graph.as_default():
def _load_graph(
self, policy: TFPolicy, model_path: str, reset_global_steps: bool = False
) -> None:
with policy.graph.as_default():
logger.info(f"Loading model from {model_path}.")
ckpt = tf.train.get_checkpoint_state(model_path)
if ckpt is None:

"--run-id and that the previous run you are loading from had the same "
"behavior names.".format(model_path)
)
try:
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
except tf.errors.NotFoundError:
raise UnityPolicyException(
"The model {} was found but could not be loaded. Make "
"sure the model is from the same version of ML-Agents, has the same behavior parameters, "
"and is using the same trainer configuration as the current run.".format(
model_path
if self.tf_saver:
try:
self.tf_saver.restore(policy.sess, ckpt.model_checkpoint_path)
except tf.errors.NotFoundError:
raise UnityPolicyException(
"The model {} was found but could not be loaded. Make "
"sure the model is from the same version of ML-Agents, has the same behavior parameters, "
"and is using the same trainer configuration as the current run.".format(
model_path
)
)
self.policy.set_step(0)
policy.set_step(0)
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path

logger.info(
f"Resuming training from step {self.policy.get_current_step()}."
)
logger.info(f"Resuming training from step {policy.get_current_step()}.")
def _check_model_version(self, version: str) -> None:
"""

if self.policy.version_tensors is not None:
if self.policy is not None and self.policy.version_tensors is not None:
loaded_ver = tuple(
num.eval(session=self.sess) for num in self.policy.version_tensors
)

f"{loaded_ver[0]}.{loaded_ver[1]}.{loaded_ver[2]} but your current ML-Agents"
f"version is {version}. Model may not behave properly."
)
def copy_final_model(self, source_nn_path: str) -> None:
"""
Copy the .nn file at the given source to the destination.
Also copies the corresponding .onnx file if it exists.
"""
final_model_name = os.path.splitext(source_nn_path)[0]
if SerializationSettings.convert_to_barracuda:
source_path = f"{final_model_name}.nn"
destination_path = f"{self.model_path}.nn"
shutil.copyfile(source_path, destination_path)
logger.info(f"Copied {source_path} to {destination_path}.")
if SerializationSettings.convert_to_onnx:
try:
source_path = f"{final_model_name}.onnx"
destination_path = f"{self.model_path}.onnx"
shutil.copyfile(source_path, destination_path)
logger.info(f"Copied {source_path} to {destination_path}.")
except OSError:
pass

99
ml-agents/mlagents/trainers/saver/torch_saver.py


import os
import shutil
from typing import Dict
from typing import Dict, Union, Optional, cast
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.settings import TrainerSettings, SerializationSettings
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.torch.model_serialization import ModelSerializer

"""
def __init__(
self,
policy: TorchPolicy,
trainer_settings: TrainerSettings,
model_path: str,
load: bool = False,
self, trainer_settings: TrainerSettings, model_path: str, load: bool = False
self.policy = policy
self.exporter = ModelSerializer(self.policy)
self.policy: Optional[TorchPolicy] = None
self.exporter: Optional[ModelSerializer] = None
def register(self, module):
self.modules.update(module.get_modules())
def register(self, module: Union[TorchPolicy, TorchOptimizer]) -> None:
self.modules.update(module.get_modules()) # type: ignore
if self.policy is None and isinstance(module, TorchPolicy):
self.policy = module
self.exporter = ModelSerializer(self.policy)
"""
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param brain_name: Brain name of brain to be trained
"""
if not os.path.exists(self.model_path):
os.makedirs(self.model_path)
checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")

self.export(checkpoint_path, brain_name)
return checkpoint_path
def maybe_load(self):
# If there is an initialize path, load from that. Else, load from the set model path.
# If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
# e.g., resume from an initialize path.
def export(self, output_filepath: str, brain_name: str) -> None:
if self.exporter is not None:
self.exporter.export_policy_model(output_filepath)
def initialize_or_load(self, policy: Optional[TorchPolicy] = None) -> None:
# Initialize/Load registered self.policy by default.
# If given input argument policy, use the input policy instead.
# This argument is mainly for initialization of the ghost trainer's fixed policy.
self._load_model(self.initialize_path, reset_global_steps=reset_steps)
self._load_model(
self.initialize_path, policy, reset_global_steps=reset_steps
)
self._load_model(self.model_path, reset_global_steps=reset_steps)
self._load_model(self.model_path, policy, reset_global_steps=reset_steps)
def export(self, output_filepath: str, brain_name: str) -> None:
self.exporter.export_policy_model(output_filepath)
def _load_model(self, load_path: str, reset_global_steps: bool = False) -> None:
def _load_model(
self,
load_path: str,
policy: Optional[TorchPolicy] = None,
reset_global_steps: bool = False,
) -> None:
for name, state_dict in saved_state_dict.items():
self.modules[name].load_state_dict(state_dict)
if policy is None:
modules = self.modules
policy = self.policy
else:
modules = policy.get_modules()
policy = cast(TorchPolicy, policy)
for name, mod in modules.items():
mod.load_state_dict(saved_state_dict[name])
self.policy.set_step(0)
policy.set_step(0)
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path

logger.info(
f"Resuming training from step {self.policy.get_current_step()}."
)
logger.info(f"Resuming training from step {policy.get_current_step()}.")
def copy_final_model(self, source_nn_path: str) -> None:
"""
Copy the .nn file at the given source to the destination.
Also copies the corresponding .onnx file if it exists.
"""
final_model_name = os.path.splitext(source_nn_path)[0]
if SerializationSettings.convert_to_barracuda:
source_path = f"{final_model_name}.nn"
destination_path = f"{self.model_path}.nn"
shutil.copyfile(source_path, destination_path)
logger.info(f"Copied {source_path} to {destination_path}.")
if SerializationSettings.convert_to_onnx:
try:
source_path = f"{final_model_name}.onnx"
destination_path = f"{self.model_path}.onnx"
shutil.copyfile(source_path, destination_path)
logger.info(f"Copied {source_path} to {destination_path}.")
except OSError:
pass

10
ml-agents/mlagents/trainers/tests/test_bcmodule.py


NetworkSettings.MemorySettings() if use_rnn else None
)
policy = TFPolicy(
0,
mock_behavior_specs,
trainer_config,
"test",
False,
tanhresample,
tanhresample,
0, mock_behavior_specs, trainer_config, tanhresample, tanhresample
)
with policy.graph.as_default():
bc_module = BCModule(

default_num_epoch=3,
settings=bc_settings,
)
policy.initialize_or_load() # Normally the optimizer calls this after the BCModule is created
policy.initialize() # Normally the optimizer calls this after the BCModule is created
return bc_module

9
ml-agents/mlagents/trainers/tests/test_ppo.py


0, mock_specs, trainer_settings, "test", False, create_tf_graph=False
)
optimizer = PPOOptimizer(policy, trainer_settings)
policy.initialize()
return optimizer

)
@mock.patch("mlagents.trainers.ppo.trainer.TFPPOOptimizer")
@mock.patch("mlagents.trainers.ppo.trainer.PPOOptimizer")
def test_trainer_increment_step(ppo_optimizer, dummy_config):
trainer_params = PPO_CONFIG
mock_optimizer = mock.Mock()

)
policy_mock.increment_step = mock.Mock(return_value=step_count)
behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)
trainer.add_policy(behavior_id, policy_mock, create_saver=False)
trainer.add_policy(behavior_id, policy_mock, register_saver=False)
trainer._increment_step(5, trainer.brain_name)
policy_mock.increment_step.assert_called_with(5)

assert trainer.stats_reporter.get_stats_summaries("Policy/Extrinsic Reward").num > 0
@mock.patch("mlagents.trainers.ppo.trainer.TFPPOOptimizer")
@mock.patch("mlagents.trainers.ppo.trainer.PPOOptimizer")
def test_add_get_policy(ppo_optimizer, dummy_config):
mock_optimizer = mock.Mock()
mock_optimizer.reward_signals = {}

policy.get_current_step.return_value = 2000
behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)
trainer.add_policy(behavior_id, policy, create_saver=False)
trainer.add_policy(behavior_id, policy, register_saver=False)
assert trainer.get_policy("test_policy") == policy
# Make sure the summary steps were loaded properly

1
ml-agents/mlagents/trainers/tests/test_reward_signals.py


optimizer = SACOptimizer(policy, trainer_settings)
else:
optimizer = PPOOptimizer(policy, trainer_settings)
optimizer.policy.initialize()
return optimizer

12
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


def _update_policy(self):
return self.update_policy
def add_policy(self, mock_behavior_id, mock_policy, create_saver=True):
def add_policy(self, mock_behavior_id, mock_policy):
def checkpoint_path(brain_name, step):
return os.path.join(self.saver.model_path, f"{brain_name}-{step}")

mock_saver.save_checkpoint.side_effect = checkpoint_path
self.saver = mock_saver
def create_tf_policy(self):
def create_tf_policy(self, parsed_behavior_id, behavior_spec):
def create_torch_policy(self):
def create_torch_policy(self, parsed_behavior_id, behavior_spec):
def create_torch_policy(self, parsed_behavior_id, behavior_spec):
return mock.Mock()
def create_tf_policy(self, parsed_behavior_id, behavior_spec):
return mock.Mock()
def create_rl_trainer():

4
ml-agents/mlagents/trainers/tests/test_sac.py


0, mock_brain, trainer_settings, "test", False, create_tf_graph=False
)
optimizer = SACOptimizer(policy, trainer_settings)
optimizer.policy.initialize()
return optimizer

policy = mock.Mock(spec=TFPolicy)
policy.get_current_step.return_value = 2000
behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)
trainer.add_policy(behavior_id, policy, create_saver=False)
trainer.add_policy(behavior_id, policy, register_saver=False)
assert trainer.get_policy(behavior_id.behavior_id) == policy
# Make sure the summary steps were loaded properly

policy.get_current_step = lambda: 200
trainer.add_policy(behavior_id, policy)
trainer.optimizer.update = mock.Mock()
trainer.saver.initialize_or_load(policy)
trainer.optimizer.update_reward_signals = mock.Mock()
trainer.optimizer.update_reward_signals.return_value = {}
trainer.optimizer.update.return_value = {}

4
ml-agents/mlagents/trainers/tests/test_simple_rl.py


@pytest.mark.parametrize("use_discrete", [True, False])
def test_recurrent_sac(use_discrete):
step_size = 0.2 if use_discrete else 1.0
step_size = 0.5 if use_discrete else 0.2
env = MemoryEnvironment(
[BRAIN_NAME], use_discrete=use_discrete, step_size=step_size
)

swap_steps=5000,
team_change=2000,
)
config = attr.evolve(PPO_CONFIG, self_play=self_play_settings, max_steps=2000)
config = attr.evolve(PPO_CONFIG, self_play=self_play_settings, max_steps=3000)
_check_environment_trains(
env, {BRAIN_NAME: config, brain_name_opp: config}, success_threshold=None
)

9
ml-agents/mlagents/trainers/tf/model_serialization.py


def export_policy_model(
output_filepath: str, brain_name: str, graph: tf.Graph, sess: tf.Session
model_path: str,
output_filepath: str,
brain_name: str,
graph: tf.Graph,
sess: tf.Session,
:param brain_name: brain name of the trained model
:param graph: Tensorflow Graph for the policy
:param sess: Tensorflow session for the policy
"""

# Save frozen graph
frozen_graph_def_path = output_filepath + "/frozen_graph_def.pb"
frozen_graph_def_path = model_path + "/frozen_graph_def.pb"
with gfile.GFile(frozen_graph_def_path, "wb") as f:
f.write(frozen_graph_def.SerializeToString())

2
ml-agents/mlagents/trainers/torch/model_serialization.py


output_names=self.output_names,
dynamic_axes=self.dynamic_axes,
)
logger.info(f"Exported {onnx_output_path}.onnx")
logger.info(f"Exported {onnx_output_path}")

32
ml-agents/mlagents/trainers/trainer/rl_trainer.py


import abc
import time
import attr
from mlagents.model_serialization import copy_model_files
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpoint,
NNCheckpointManager,

from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.settings import TestingConfiguration, TrainerSettings, FrameworkType
from mlagents.trainers.settings import (
TestingConfiguration,
TrainerSettings,
FrameworkType,
)
from mlagents.trainers.stats import StatsPropertyType
from mlagents.trainers.saver.saver import BaseSaver
from mlagents.trainers.saver.torch_saver import TorchSaver

self.trainer_settings.max_steps = TestingConfiguration.max_steps
self._next_save_step = 0
self._next_summary_step = 0
self.saver = None
self.saver = self.create_saver(
self.framework, self.trainer_settings, self.artifact_path, self.load
)
def end_episode(self) -> None:
"""

@staticmethod
def create_saver(
framework: str,
policy: Policy,
trainer_settings: TrainerSettings,
model_path: str,
load: bool,
framework: str, trainer_settings: TrainerSettings, model_path: str, load: bool
policy, # type: ignore
trainer_settings,
model_path,
load,
trainer_settings, model_path, load
policy, # type: ignore
trainer_settings,
model_path,
load,
trainer_settings, model_path, load
)
return saver

return
model_checkpoint = self._checkpoint()
# Copy the checkpointed model files to the final output location
copy_model_files(model_checkpoint.file_path, f"{policy.model_path}.nn")
self.saver.copy_final_model(model_checkpoint.file_path)
final_checkpoint = attr.evolve(
model_checkpoint, file_path=f"{self.saver.model_path}.nn"
)

2
ml-agents/mlagents/trainers/trainer/trainer.py


self,
parsed_behavior_id: BehaviorIdentifiers,
policy: Policy,
create_saver: bool = True,
register_saver: bool = True,
) -> None:
"""
Adds policy to trainer.

113
ml-agents/mlagents/trainers/tests/test_saver.py


import pytest
from unittest import mock
import os
import unittest
import tempfile
import numpy as np
from mlagents.tf_utils import tf
from mlagents.trainers.saver.tf_saver import TFSaver
from mlagents.trainers import __version__
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.test_nn_policy import create_policy_mock
from mlagents.trainers.ppo.optimizer_tf import PPOOptimizer
def test_register(tmp_path):
trainer_params = TrainerSettings()
saver = TFSaver(trainer_params, tmp_path)
opt = mock.Mock(spec=PPOOptimizer)
saver.register(opt)
assert saver.policy is None
trainer_params = TrainerSettings()
policy = create_policy_mock(trainer_params)
saver.register(policy)
assert saver.policy is not None
class ModelVersionTest(unittest.TestCase):
def test_version_compare(self):
# Test write_stats
with self.assertLogs("mlagents.trainers", level="WARNING") as cm:
trainer_params = TrainerSettings()
mock_path = tempfile.mkdtemp()
policy = create_policy_mock(trainer_params)
saver = TFSaver(trainer_params, mock_path)
saver.register(policy)
saver._check_model_version(
"0.0.0"
) # This is not the right version for sure
# Assert that 1 warning has been thrown with incorrect version
assert len(cm.output) == 1
saver._check_model_version(__version__) # This should be the right version
# Assert that no additional warnings have been thrown wth correct ver
assert len(cm.output) == 1
def test_load_save(tmp_path):
path1 = os.path.join(tmp_path, "runid1")
path2 = os.path.join(tmp_path, "runid2")
trainer_params = TrainerSettings()
policy = create_policy_mock(trainer_params)
saver = TFSaver(trainer_params, path1)
saver.register(policy)
saver.initialize_or_load(policy)
policy.set_step(2000)
mock_brain_name = "MockBrain"
saver.save_checkpoint(mock_brain_name, 2000)
assert len(os.listdir(tmp_path)) > 0
# Try load from this path
saver = TFSaver(trainer_params, path1, load=True)
policy2 = create_policy_mock(trainer_params)
saver.register(policy2)
saver.initialize_or_load(policy2)
_compare_two_policies(policy, policy2)
assert policy2.get_current_step() == 2000
# Try initialize from path 1
trainer_params.init_path = path1
saver = TFSaver(trainer_params, path2)
policy3 = create_policy_mock(trainer_params)
saver.register(policy3)
saver.initialize_or_load(policy3)
_compare_two_policies(policy2, policy3)
# Assert that the steps are 0.
assert policy3.get_current_step() == 0
def _compare_two_policies(policy1: TFPolicy, policy2: TFPolicy) -> None:
"""
Make sure two policies have the same output for the same input.
"""
decision_step, _ = mb.create_steps_from_behavior_spec(
policy1.behavior_spec, num_agents=1
)
run_out1 = policy1.evaluate(decision_step, list(decision_step.agent_id))
run_out2 = policy2.evaluate(decision_step, list(decision_step.agent_id))
np.testing.assert_array_equal(run_out2["log_probs"], run_out1["log_probs"])
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_checkpoint_conversion(tmpdir, rnn, visual, discrete):
tf.reset_default_graph()
dummy_config = TrainerSettings()
model_path = os.path.join(tmpdir, "Mock_Brain")
policy = create_policy_mock(
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
trainer_params = TrainerSettings()
saver = TFSaver(trainer_params, model_path)
saver.register(policy)
saver.save_checkpoint("Mock_Brain", 100)
assert os.path.isfile(model_path + "/Mock_Brain-100.nn")
正在加载...
取消
保存