|
|
|
|
|
|
use_rnn: bool = False, |
|
|
|
use_discrete: bool = True, |
|
|
|
use_visual: bool = False, |
|
|
|
model_path: str = "", |
|
|
|
load: bool = False, |
|
|
|
seed: int = 0, |
|
|
|
) -> NNPolicy: |
|
|
|
|
|
|
trainer_settings.network_settings.memory = ( |
|
|
|
NetworkSettings.MemorySettings() if use_rnn else None |
|
|
|
) |
|
|
|
policy = NNPolicy(seed, mock_brain, trainer_settings, False, load) |
|
|
|
policy = NNPolicy(seed, mock_brain, trainer_settings, False, model_path, load) |
|
|
|
return policy |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trainer_params = TrainerSettings(output_path=path1) |
|
|
|
policy = create_policy_mock(trainer_params) |
|
|
|
trainer_params = TrainerSettings() |
|
|
|
policy = create_policy_mock(trainer_params, model_path=path1) |
|
|
|
policy.initialize_or_load() |
|
|
|
policy._set_step(2000) |
|
|
|
policy.save_model(2000) |
|
|
|
|
|
|
# Try load from this path |
|
|
|
policy2 = create_policy_mock(trainer_params, load=True, seed=1) |
|
|
|
policy2 = create_policy_mock(trainer_params, model_path=path1, load=True, seed=1) |
|
|
|
policy2.initialize_or_load() |
|
|
|
_compare_two_policies(policy, policy2) |
|
|
|
assert policy2.get_current_step() == 2000 |
|
|
|
|
|
|
trainer_params.init_path = path1 |
|
|
|
policy3 = create_policy_mock(trainer_params, load=False, seed=2) |
|
|
|
policy3 = create_policy_mock(trainer_params, model_path=path1, load=False, seed=2) |
|
|
|
policy3.initialize_or_load() |
|
|
|
|
|
|
|
_compare_two_policies(policy2, policy3) |
|
|
|
|
|
|
# Test write_stats |
|
|
|
with self.assertLogs("mlagents.trainers", level="WARNING") as cm: |
|
|
|
path1 = tempfile.mkdtemp() |
|
|
|
trainer_params = TrainerSettings(output_path=path1) |
|
|
|
policy = create_policy_mock(trainer_params) |
|
|
|
trainer_params = TrainerSettings() |
|
|
|
policy = create_policy_mock(trainer_params, model_path=path1) |
|
|
|
policy.initialize_or_load() |
|
|
|
policy._check_model_version( |
|
|
|
"0.0.0" |
|
|
|
|
|
|
brain_params, |
|
|
|
TrainerSettings(network_settings=NetworkSettings(normalize=True)), |
|
|
|
False, |
|
|
|
"testdir", |
|
|
|
False, |
|
|
|
) |
|
|
|
|
|
|
|