浏览代码

Fix bug where constant LR in pretraining will throw TF error (#2977)

/develop/tanhsquash
GitHub 5 年前
当前提交
2c7e6d51
共有 2 个文件被更改,包括 21 次插入和 1 次删除
  1. 2
      ml-agents/mlagents/trainers/components/bc/model.py
  2. 20
      ml-agents/mlagents/trainers/tests/test_bcmodule.py

2
ml-agents/mlagents/trainers/components/bc/model.py


power=1.0,
)
else:
- self.annealed_learning_rate = learning_rate
+ self.annealed_learning_rate = tf.Variable(learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate=self.annealed_learning_rate)
self.update_batch = optimizer.minimize(self.loss)

20
ml-agents/mlagents/trainers/tests/test_bcmodule.py


env.close()
# Test with constant pretraining learning rate
@pytest.mark.parametrize(
    "trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
@mock.patch("mlagents.envs.environment.UnityEnvironment")
def test_bcmodule_constant_lr_update(mock_env, trainer_config):
    """Verify BC module updates succeed and the LR stays constant when steps == 0.

    With pretraining ``steps`` set to 0, no annealing schedule is created, so
    ``current_lr`` must be identical before and after an additional update.
    """
    mock_brain = mb.create_mock_3dball_brain()
    # steps == 0 selects the constant-learning-rate code path (no decay).
    trainer_config["pretraining"]["steps"] = 0
    env, policy = create_policy_with_bc_mock(
        mock_env, mock_brain, trainer_config, False, "test.demo"
    )
    stats = policy.bc_module.update()
    for _, item in stats.items():
        assert isinstance(item, np.float32)
    old_learning_rate = policy.bc_module.current_lr
    # A second update must not change the (constant) learning rate;
    # its stats are irrelevant here, so the return value is discarded.
    policy.bc_module.update()
    assert old_learning_rate == policy.bc_module.current_lr
    # Close the mocked environment, consistent with the sibling tests.
    env.close()
# Test with RNN
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]

正在加载...
取消
保存