
update test_saver_reward_providers.py

Branch: MLA-1734-demo-provider
Ruo-Ping Dong, 4 years ago
Commit: 56feb8af
1 file changed, 10 insertions(+), 10 deletions(-)
ml-agents/mlagents/trainers/tests/torch/saver/test_saver_reward_providers.py (20 changed lines)

 from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
 from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
-from mlagents.trainers.saver.torch_saver import TorchSaver
+from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
 from mlagents.trainers.settings import (
     TrainerSettings,
     RewardSignalType,

     # save at path 1
     path1 = os.path.join(tmp_path, "runid1")
-    saver = TorchSaver(trainer_settings, path1)
-    saver.register(policy)
-    saver.register(optimizer)
-    saver.initialize_or_load()
+    model_saver = TorchModelSaver(trainer_settings, path1)
+    model_saver.register(policy)
+    model_saver.register(optimizer)
+    model_saver.initialize_or_load()
-    saver.save_checkpoint("MockBrain", 2000)
+    model_saver.save_checkpoint("MockBrain", 2000)
     # create a new optimizer and policy
     optimizer2 = OptimizerClass(policy, trainer_settings)

-    saver2 = TorchSaver(trainer_settings, path1, load=True)
-    saver2.register(policy2)
-    saver2.register(optimizer2)
-    saver2.initialize_or_load() # This is to load the optimizers
+    model_saver2 = TorchModelSaver(trainer_settings, path1, load=True)
+    model_saver2.register(policy2)
+    model_saver2.register(optimizer2)
+    model_saver2.initialize_or_load() # This is to load the optimizers
     # assert the models have the same weights
     module_dict_1 = optimizer.get_modules()
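
The hunk ends right where the test begins comparing the reloaded weights (module_dict_1 = optimizer.get_modules()). As a rough, minimal sketch of what that comparison can look like, assuming get_modules() returns a dict of named torch modules and that optimizer2 has already been reloaded through model_saver2 as in the diff above; assert_same_weights is a hypothetical helper for illustration, not code from the repository:

# Minimal sketch (not the repository's code): compare each registered torch
# module of the original optimizer with its reloaded counterpart.
from typing import Dict

import torch
from torch import nn


def assert_same_weights(modules_a: Dict[str, nn.Module], modules_b: Dict[str, nn.Module]) -> None:
    # Both runs registered the same policy/optimizer, so the reloaded run
    # should expose the same module names with identical parameter tensors.
    assert set(modules_a) == set(modules_b)
    for name, module_a in modules_a.items():
        module_b = modules_b[name]
        if not isinstance(module_a, nn.Module):
            continue  # get_modules() may also contain non-module entries
        for p_a, p_b in zip(module_a.parameters(), module_b.parameters()):
            assert torch.equal(p_a.data, p_b.data), f"weights differ in {name}"


# Hypothetical usage after model_saver2.initialize_or_load():
# assert_same_weights(optimizer.get_modules(), optimizer2.get_modules())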
