浏览代码

test convolutions can be loaded properly

/fix-resume-imi
Andrew Cohen 4 年前
当前提交
98dcb548
共有 1 个文件被更改,包括 38 次插入0 次删除
  1. 38
      ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py

38
ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py


from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
from mlagents.trainers.settings import (
TrainerSettings,
NetworkSettings,
EncoderType,
PPOSettings,
SACSettings,
POCASettings,

_compare_two_policies(policy2, policy3)
# Assert that the steps are 0.
assert policy3.get_current_step() == 0
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
def test_load_policy_different_hidden_units(tmp_path, vis_encode_type):
path1 = os.path.join(tmp_path, "runid1")
trainer_params = TrainerSettings()
trainer_params.network_settings = NetworkSettings(
hidden_units=12, vis_encode_type=EncoderType(vis_encode_type)
)
policy = create_policy_mock(trainer_params, use_visual=True)
conv_params = [mod for mod in policy.actor.parameters() if len(mod.shape) > 2]
model_saver = TorchModelSaver(trainer_params, path1)
model_saver.register(policy)
model_saver.initialize_or_load(policy)
policy.set_step(2000)
mock_brain_name = "MockBrain"
model_saver.save_checkpoint(mock_brain_name, 2000)
# Try load from this path
trainer_params2 = TrainerSettings()
trainer_params2.network_settings = NetworkSettings(
hidden_units=10, vis_encode_type=EncoderType(vis_encode_type)
)
model_saver2 = TorchModelSaver(trainer_params2, path1, load=True)
policy2 = create_policy_mock(trainer_params2, use_visual=True)
conv_params2 = [mod for mod in policy2.actor.parameters() if len(mod.shape) > 2]
# asserts convolutions have different parameters before load
for conv1, conv2 in zip(conv_params, conv_params2):
assert not torch.equal(conv1, conv2)
model_saver2.register(policy2)
model_saver2.initialize_or_load(policy2)
# asserts convolutions have same parameters after load
for conv1, conv2 in zip(conv_params, conv_params2):
assert torch.equal(conv1, conv2)
@pytest.mark.parametrize(

正在加载...
取消
保存