浏览代码

update torch saver test

/MLA-1734-demo-provider
Ruo-Ping Dong 4 年前
当前提交
07e82899
共有 1 个文件被更改,包括 19 次插入19 次删除
  1. 38
      ml-agents/mlagents/trainers/tests/torch/test_saver.py

38
ml-agents/mlagents/trainers/tests/torch/test_saver.py


import torch
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
from mlagents.trainers.saver.torch_saver import TorchSaver
from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.torch.test_policy import create_policy_mock

trainer_params = TrainerSettings()
saver = TorchSaver(trainer_params, tmp_path)
model_saver = TorchModelSaver(trainer_params, tmp_path)
saver.register(opt)
assert saver.policy is None
model_saver.register(opt)
assert model_saver.policy is None
saver.register(policy)
assert saver.policy is not None
model_saver.register(policy)
assert model_saver.policy is not None
def test_load_save(tmp_path):

policy = create_policy_mock(trainer_params)
saver = TorchSaver(trainer_params, path1)
saver.register(policy)
saver.initialize_or_load(policy)
model_saver = TorchModelSaver(trainer_params, path1)
model_saver.register(policy)
model_saver.initialize_or_load(policy)
saver.save_checkpoint(mock_brain_name, 2000)
model_saver.save_checkpoint(mock_brain_name, 2000)
saver2 = TorchSaver(trainer_params, path1, load=True)
model_saver2 = TorchModelSaver(trainer_params, path1, load=True)
saver2.register(policy2)
saver2.initialize_or_load(policy2)
model_saver2.register(policy2)
model_saver2.initialize_or_load(policy2)
saver3 = TorchSaver(trainer_params, path2)
model_saver3 = TorchModelSaver(trainer_params, path2)
saver3.register(policy3)
saver3.initialize_or_load(policy3)
model_saver3.register(policy3)
model_saver3.initialize_or_load(policy3)
_compare_two_policies(policy2, policy3)
# Assert that the steps are 0.
assert policy3.get_current_step() == 0

dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
trainer_params = TrainerSettings()
saver = TorchSaver(trainer_params, model_path)
saver.register(policy)
saver.save_checkpoint("Mock_Brain", 100)
model_saver = TorchModelSaver(trainer_params, model_path)
model_saver.register(policy)
model_saver.save_checkpoint("Mock_Brain", 100)
assert os.path.isfile(model_path + "/Mock_Brain-100.onnx")
正在加载...
取消
保存