|
|
|
|
|
|
|
|
|
|
@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_policy_outputs") |
|
|
|
@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_rewards_outputs") |
|
|
|
def test_rl_trainer(add_policy_outputs, add_rewards_outputs): |
|
|
|
@pytest.mark.parametrize("num_vis_obs", [0, 1], ids=["", "viz"]) |
|
|
|
def test_rl_trainer(add_policy_outputs, add_rewards_outputs, num_vis_obs): |
|
|
|
trainer = create_rl_trainer() |
|
|
|
trainer.policy = create_mock_policy() |
|
|
|
fake_action_outputs = { |
|
|
|
|
|
|
num_agents=2, |
|
|
|
num_vector_observations=8, |
|
|
|
num_vector_acts=2, |
|
|
|
num_vis_observations=1, |
|
|
|
num_vis_observations=num_vis_obs, |
|
|
|
) |
|
|
|
trainer.add_experiences( |
|
|
|
create_mock_all_brain_info(mock_braininfo), |
|
|
|
|
|
|
num_agents=1, |
|
|
|
num_vector_observations=8, |
|
|
|
num_vector_acts=2, |
|
|
|
num_vis_observations=1, |
|
|
|
num_vis_observations=num_vis_obs, |
|
|
|
assert len(brain_info.visual_observations) == 1 |
|
|
|
assert len(brain_info.visual_observations[0]) == 1 |
|
|
|
assert len(brain_info.visual_observations) == num_vis_obs |
|
|
|
assert len(brain_info.vector_observations) == 1 |
|
|
|
assert len(brain_info.previous_vector_actions) == 1 |
|
|
|
|
|
|
|