|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
|
from mlagents.tf_utils import tf |
|
|
|
from mlagents.trainers.saver.tf_saver import TFSaver |
|
|
|
from mlagents.trainers.model_saver.tf_model_saver import TFModelSaver |
|
|
|
from mlagents.trainers import __version__ |
|
|
|
from mlagents.trainers.settings import TrainerSettings |
|
|
|
from mlagents.trainers.policy.tf_policy import TFPolicy |
|
|
|
|
|
|
|
|
|
|
def test_register(tmp_path): |
# Checks how TFSaver and TFModelSaver respond to register(): after registering `opt`, |
# each saver's `.policy` must still be None; after registering `policy`, it must be set. |
# Both savers are built from a default TrainerSettings() and `tmp_path` |
# (presumably pytest's tmp_path fixture -- TODO confirm). |
# NOTE(review): `opt` and `policy` are never assigned in the lines visible here; |
# their setup appears to be missing from this excerpt. |
|
|
|
trainer_params = TrainerSettings() |
|
|
|
saver = TFSaver(trainer_params, tmp_path) |
|
|
|
model_saver = TFModelSaver(trainer_params, tmp_path) |
|
|
|
# Registering `opt` is expected to leave the saver's policy unset. |
saver.register(opt) |
|
|
|
assert saver.policy is None |
|
|
|
model_saver.register(opt) |
|
|
|
assert model_saver.policy is None |
|
|
|
# Registering `policy` is expected to populate the saver's policy. |
saver.register(policy) |
|
|
|
assert saver.policy is not None |
|
|
|
model_saver.register(policy) |
|
|
|
assert model_saver.policy is not None |
|
|
|
|
|
|
|
|
|
|
|
class ModelVersionTest(unittest.TestCase): |
# Exercises `_check_model_version` on a TFSaver and a TFModelSaver that share one |
# mocked policy, then asserts that `cm.output` holds exactly one entry. |
# NOTE(review): the test-method header and the statement that binds `cm` are not |
# visible in this excerpt -- presumably a log-capturing context; TODO confirm. |
|
|
|
|
|
|
trainer_params = TrainerSettings() |
|
|
|
# Fresh temporary directory used as the savers' path. |
mock_path = tempfile.mkdtemp() |
|
|
|
policy = create_policy_mock(trainer_params) |
|
|
|
saver = TFSaver(trainer_params, mock_path) |
|
|
|
saver.register(policy) |
|
|
|
model_saver = TFModelSaver(trainer_params, mock_path) |
|
|
|
model_saver.register(policy) |
|
|
|
# NOTE(review): the argument and closing parenthesis of the next two calls are not |
# present in this excerpt. |
saver._check_model_version( |
|
|
|
model_saver._check_model_version( |
|
|
|
saver._check_model_version(__version__) # This should be the right version |
|
|
|
model_saver._check_model_version( |
|
|
|
__version__ |
|
|
|
) # This should be the right version |
|
|
|
# Assert that no additional warnings have been thrown with the correct version |
|
|
|
assert len(cm.output) == 1 |
|
|
|
|
|
|
|
|
|
|
path2 = os.path.join(tmp_path, "runid2") |
|
|
|
trainer_params = TrainerSettings() |
|
|
|
policy = create_policy_mock(trainer_params) |
|
|
|
saver = TFSaver(trainer_params, path1) |
|
|
|
saver.register(policy) |
|
|
|
saver.initialize_or_load(policy) |
|
|
|
model_saver = TFModelSaver(trainer_params, path1) |
|
|
|
model_saver.register(policy) |
|
|
|
model_saver.initialize_or_load(policy) |
|
|
|
saver.save_checkpoint(mock_brain_name, 2000) |
|
|
|
model_saver.save_checkpoint(mock_brain_name, 2000) |
|
|
|
saver = TFSaver(trainer_params, path1, load=True) |
|
|
|
model_saver = TFModelSaver(trainer_params, path1, load=True) |
|
|
|
saver.register(policy2) |
|
|
|
saver.initialize_or_load(policy2) |
|
|
|
model_saver.register(policy2) |
|
|
|
model_saver.initialize_or_load(policy2) |
|
|
|
saver = TFSaver(trainer_params, path2) |
|
|
|
model_saver = TFModelSaver(trainer_params, path2) |
|
|
|
saver.register(policy3) |
|
|
|
saver.initialize_or_load(policy3) |
|
|
|
model_saver.register(policy3) |
|
|
|
model_saver.initialize_or_load(policy3) |
|
|
|
|
|
|
|
_compare_two_policies(policy2, policy3) |
|
|
|
# Assert that the steps are 0. |
|
|
|
|
|
|
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual |
|
|
|
) |
|
|
|
trainer_params = TrainerSettings() |
|
|
|
saver = TFSaver(trainer_params, model_path) |
|
|
|
saver.register(policy) |
|
|
|
saver.save_checkpoint("Mock_Brain", 100) |
|
|
|
model_saver = TFModelSaver(trainer_params, model_path) |
|
|
|
model_saver.register(policy) |
|
|
|
model_saver.save_checkpoint("Mock_Brain", 100) |
|
|
|
assert os.path.isfile(model_path + "/Mock_Brain-100.nn") |