from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
from mlagents.trainers.settings import (
    TrainerSettings,
    NetworkSettings,
    EncoderType,
    PPOSettings,
    SACSettings,
    POCASettings,
)
|
|
    _compare_two_policies(policy2, policy3)
    # Assert that the steps are 0.
    assert policy3.get_current_step() == 0


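# A checkpoint written with one hidden_units setting should still restore the
# shape-compatible visual-encoder convolutions into a network built with a
# different hidden_units setting.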
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
def test_load_policy_different_hidden_units(tmp_path, vis_encode_type):
    path1 = os.path.join(tmp_path, "runid1")
    trainer_params = TrainerSettings()
    trainer_params.network_settings = NetworkSettings(
        hidden_units=12, vis_encode_type=EncoderType(vis_encode_type)
    )
    policy = create_policy_mock(trainer_params, use_visual=True)
    conv_params = [mod for mod in policy.actor.parameters() if len(mod.shape) > 2]

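    # Create a saver for the first run and write a checkpoint at step 2000.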
    model_saver = TorchModelSaver(trainer_params, path1)
    model_saver.register(policy)
    model_saver.initialize_or_load(policy)
    policy.set_step(2000)

    mock_brain_name = "MockBrain"
    model_saver.save_checkpoint(mock_brain_name, 2000)

    # Try to load from this path into a policy with different hidden_units.
    trainer_params2 = TrainerSettings()
    trainer_params2.network_settings = NetworkSettings(
        hidden_units=10, vis_encode_type=EncoderType(vis_encode_type)
    )
    model_saver2 = TorchModelSaver(trainer_params2, path1, load=True)
    policy2 = create_policy_mock(trainer_params2, use_visual=True)
    conv_params2 = [mod for mod in policy2.actor.parameters() if len(mod.shape) > 2]
    # Assert that the convolutions have different parameters before the load.
    for conv1, conv2 in zip(conv_params, conv_params2):
        assert not torch.equal(conv1, conv2)
    model_saver2.register(policy2)
    model_saver2.initialize_or_load(policy2)
    # Assert that the convolutions have the same parameters after the load.
    for conv1, conv2 in zip(conv_params, conv_params2):
        assert torch.equal(conv1, conv2)


@pytest.mark.parametrize( |
|
|
|