浏览代码

fix tf ghost tests

/develop/add-fire/ghost
Andrew Cohen 4 年前
当前提交
39bca7d2
共有 2 个文件被更改,包括 4 次插入7 次删除
  1. 7
      ml-agents/mlagents/trainers/tests/test_ghost.py
  2. 4
      ml-agents/mlagents/trainers/tests/torch/test_ghost.py

7
ml-agents/mlagents/trainers/tests/test_ghost.py


trainer_params = dummy_config
trainer = PPOTrainer("test", 0, trainer_params, True, False, 0, "0")
trainer.seed = 1
policy = trainer.create_policy("test", mock_specs)
policy.create_tf_graph()
policy = trainer.create_policy("test", mock_specs, create_graph=True)
to_load_policy = trainer.create_policy("test", mock_specs)
to_load_policy.create_tf_graph()
to_load_policy.init_load_weights()
to_load_policy = trainer.create_policy("test", mock_specs, create_graph=True)
weights = policy.get_weights()
load_weights = to_load_policy.get_weights()

4
ml-agents/mlagents/trainers/tests/torch/test_ghost.py


trainer_params = dummy_config
trainer = PPOTrainer("test", 0, trainer_params, True, False, 0, "0")
trainer.seed = 1
policy = trainer.create_policy("test", mock_specs, create_graph=True)
policy = trainer.create_policy("test", mock_specs)
to_load_policy = trainer.create_policy("test", mock_specs, create_graph=True)
to_load_policy = trainer.create_policy("test", mock_specs)
weights = policy.get_weights()
load_weights = to_load_policy.get_weights()

正在加载...
取消
保存