|
|
|
|
|
|
env.close() |
|
|
|
|
|
|
|
|
|
|
|
# Test that the pretraining learning rate stays constant across updates when steps is 0
|
|
|
@pytest.mark.parametrize(
    "trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
@mock.patch("mlagents.envs.environment.UnityEnvironment")
def test_bcmodule_constant_lr_update(mock_env, trainer_config):
    """Check that the BC module's learning rate does not change between updates.

    Runs once per parametrized trainer config (ids "ppo" and "sac"). With the
    pretraining ``steps`` set to 0, two consecutive ``bc_module.update()``
    calls must leave ``bc_module.current_lr`` unchanged, and every statistic
    returned by the first update must be an ``np.float32``.

    :param mock_env: Mock injected by ``mock.patch`` in place of
        ``mlagents.envs.environment.UnityEnvironment``.
    :param trainer_config: Trainer configuration dict supplied by
        ``pytest.mark.parametrize``; its ``["pretraining"]["steps"]`` entry is
        overwritten in place.
    """
    mock_brain = mb.create_mock_3dball_brain()
    # NOTE(review): steps == 0 is presumably how the BC module is told to keep
    # a constant learning rate -- confirm against the BC module implementation.
    trainer_config["pretraining"]["steps"] = 0
    env, policy = create_policy_with_bc_mock(
        mock_env, mock_brain, trainer_config, False, "test.demo"
    )
    # Close the environment even when an assertion fails, so this test cleans
    # up the same way the other tests in this file do.
    try:
        stats = policy.bc_module.update()
        for item in stats.values():
            assert isinstance(item, np.float32)
        old_learning_rate = policy.bc_module.current_lr

        # Only the effect on the learning rate matters for the second update;
        # its returned statistics are intentionally ignored.
        policy.bc_module.update()
        assert old_learning_rate == policy.bc_module.current_lr
    finally:
        env.close()
|
|
|
|
|
|
|
|
|
|
|
# Test with RNN |
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|