|
|
|
|
|
|
import torch |
|
|
|
from mlagents.trainers.policy.torch_policy import TorchPolicy |
|
|
|
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer |
|
|
|
from mlagents.trainers.saver.torch_saver import TorchSaver |
|
|
|
from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver |
|
|
|
from mlagents.trainers.settings import TrainerSettings |
|
|
|
from mlagents.trainers.tests import mock_brain as mb |
|
|
|
from mlagents.trainers.tests.torch.test_policy import create_policy_mock |
|
|
|
|
|
|
trainer_params = TrainerSettings() |
|
|
|
saver = TorchSaver(trainer_params, tmp_path) |
|
|
|
model_saver = TorchModelSaver(trainer_params, tmp_path) |
|
|
|
saver.register(opt) |
|
|
|
assert saver.policy is None |
|
|
|
model_saver.register(opt) |
|
|
|
assert model_saver.policy is None |
|
|
|
saver.register(policy) |
|
|
|
assert saver.policy is not None |
|
|
|
model_saver.register(policy) |
|
|
|
assert model_saver.policy is not None |
|
|
|
|
|
|
|
|
|
|
|
def test_load_save(tmp_path): |
# Checkpoint save/load test. In the visible lines: a policy is registered with savers rooted at path1, |
# a checkpoint labelled 2000 is written, and further savers are built on path1 (load=True) and on path2. |
# NOTE(review): trainer_params, path1, path2, mock_brain_name, policy2 and policy3 are used below but |
# never assigned in the visible lines; their setup is presumably elided from this view - confirm. |
# NOTE(review): every step appears twice, once through TorchSaver and once through TorchModelSaver; |
# this looks like both sides of a rename were kept - confirm which class is the intended one. |
|
|
|
|
|
|
policy = create_policy_mock(trainer_params) |
|
|
|
saver = TorchSaver(trainer_params, path1) |
|
|
|
saver.register(policy) |
|
|
|
saver.initialize_or_load(policy) |
|
|
|
model_saver = TorchModelSaver(trainer_params, path1) |
|
|
|
model_saver.register(policy) |
|
|
|
model_saver.initialize_or_load(policy) |
|
|
|
# Write a checkpoint for mock_brain_name with the value 2000 through both savers. |
saver.save_checkpoint(mock_brain_name, 2000) |
|
|
|
model_saver.save_checkpoint(mock_brain_name, 2000) |
|
|
|
# Second pair of savers on the same path1, this time constructed with load=True. |
saver2 = TorchSaver(trainer_params, path1, load=True) |
|
|
|
model_saver2 = TorchModelSaver(trainer_params, path1, load=True) |
|
|
|
saver2.register(policy2) |
|
|
|
saver2.initialize_or_load(policy2) |
|
|
|
model_saver2.register(policy2) |
|
|
|
model_saver2.initialize_or_load(policy2) |
|
|
|
# Third pair of savers on a different directory (path2), constructed without load=True. |
saver3 = TorchSaver(trainer_params, path2) |
|
|
|
model_saver3 = TorchModelSaver(trainer_params, path2) |
|
|
|
saver3.register(policy3) |
|
|
|
saver3.initialize_or_load(policy3) |
|
|
|
model_saver3.register(policy3) |
|
|
|
model_saver3.initialize_or_load(policy3) |
|
|
|
# _compare_two_policies is defined outside this view; presumably it checks that the two policies match - verify. |
_compare_two_policies(policy2, policy3) |
|
|
|
# Assert that the steps are 0. |
|
|
|
assert policy3.get_current_step() == 0 |
|
|
|
|
|
|
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual |
|
|
|
) |
|
|
|
trainer_params = TrainerSettings() |
|
|
|
saver = TorchSaver(trainer_params, model_path) |
|
|
|
saver.register(policy) |
|
|
|
saver.save_checkpoint("Mock_Brain", 100) |
|
|
|
model_saver = TorchModelSaver(trainer_params, model_path) |
|
|
|
model_saver.register(policy) |
|
|
|
model_saver.save_checkpoint("Mock_Brain", 100) |
|
|
|
assert os.path.isfile(model_path + "/Mock_Brain-100.onnx") |