浏览代码

fixed test_sac

/internal-policy-ghost
Andrew Cohen 5 年前
当前提交
0af2a651
共有 2 个文件被更改，包括 11 次插入和 18 次删除
  1. ml-agents/mlagents/trainers/tests/test_ppo.py（1 行更改）
  2. ml-agents/mlagents/trainers/tests/test_sac.py（28 行更改）

1
ml-agents/mlagents/trainers/tests/test_ppo.py


dummy_config["summary_path"] = "./summaries/test_trainer_summary"
dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
# nn_policy.get_current_step.return_value = 2000
trainer = PPOTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
trainer.add_policy(brain_params.brain_name, brain_params)

28
ml-agents/mlagents/trainers/tests/test_sac.py


trainer_params["model_path"] = str(tmpdir)
trainer_params["save_replay_buffer"] = True
trainer = SACTrainer(mock_brain.brain_name, 1, trainer_params, True, False, 0, 0)
policy = trainer.create_policy(mock_brain)
trainer.add_policy(mock_brain.brain_name, policy)
trainer.add_policy(mock_brain.brain_name, mock_brain)
policy = trainer.get_policy(mock_brain.brain_name)
trainer.update_buffer = mb.simulate_rollout(BUFFER_INIT_SAMPLES, policy.brain)
buffer_len = trainer.update_buffer.num_experiences

trainer2 = SACTrainer(mock_brain.brain_name, 1, trainer_params, True, True, 0, 0)
policy = trainer2.create_policy(mock_brain)
trainer2.add_policy(mock_brain.brain_name, policy)
trainer2.add_policy(mock_brain.brain_name, mock_brain)
@mock.patch("mlagents.trainers.sac.trainer.NNPolicy")
def test_add_get_policy(sac_optimizer, dummy_config):
def test_add_get_policy(sac_optimizer, nn_policy, dummy_config):
brain_params = make_brain_parameters(
discrete_action=False, visual_inputs=0, vec_obs_size=6
)

mock_policy = mock.Mock()
mock_policy.get_current_step = mock.Mock(return_value=2000)
nn_policy.return_value = mock_policy
trainer = SACTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
policy = mock.Mock(spec=NNPolicy)
policy.get_current_step.return_value = 2000
trainer.add_policy(brain_params.brain_name, policy)
assert trainer.get_policy(brain_params.brain_name) == policy
trainer = SACTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
trainer.add_policy(brain_params.brain_name, brain_params)
# Test incorrect class of policy
policy = mock.Mock()
with pytest.raises(RuntimeError):
trainer.add_policy(brain_params, policy)
def test_process_trajectory(dummy_config):
brain_params = make_brain_parameters(

dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
trainer = SACTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
policy = trainer.create_policy(brain_params)
trainer.add_policy(brain_params.brain_name, policy)
trainer.add_policy(brain_params.brain_name, brain_params)
trajectory_queue = AgentManagerQueue("testbrain")
trainer.subscribe_trajectory_queue(trajectory_queue)

正在加载...
取消
保存