|
|
|
|
|
|
|
|
|
|
@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_policy_outputs") |
|
|
|
@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.add_rewards_outputs") |
|
|
|
@pytest.mark.parametrize("num_vis_obs", [0, 1], ids=["", "viz"]) |
|
|
|
@pytest.mark.parametrize("num_vis_obs", [0, 1, 2], ids=["vec", "1 viz", "2 viz"]) |
|
|
|
def test_rl_trainer(add_policy_outputs, add_rewards_outputs, num_vis_obs): |
|
|
|
trainer = create_rl_trainer() |
|
|
|
trainer.policy = create_mock_policy() |
|
|
|