fixed sac recurrent tf simple rl

/develop/actionmodel-csharp
Andrew Cohen 4 years ago
Current commit
ff324d0c
2 files changed, with 2 insertions and 2 deletions
  1. ml-agents/mlagents/trainers/sac/optimizer_torch.py (2 lines changed)
  2. ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py (2 lines changed)

2
ml-agents/mlagents/trainers/sac/optimizer_torch.py


 ]
 policy_params = list(self.policy.actor_critic.network_body.parameters()) + list(
-    self.policy.actor_critic.distribution.parameters()
+    self.policy.actor_critic.action_model.parameters()
 )
 value_params = list(self.value_network.parameters()) + list(
     self.policy.actor_critic.critic.parameters()

2
ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py


 @pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)])
 def test_recurrent_sac(action_sizes):
-    step_size = 0.2 if action_sizes else 0.5
+    step_size = 0.2 if action_sizes == (0, 1) else 0.5
     env = MemoryEnvironment(
         [BRAIN_NAME], action_sizes=action_sizes, step_size=step_size
     )
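
The `step_size` change in the test above looks like a truthiness fix: a non-empty tuple such as `(1, 0)` is always truthy in Python, so the old condition would have picked 0.2 for both parametrized cases. A minimal sketch of the difference follows; the helper names `step_size_before` and `step_size_after` are hypothetical and are not part of ml-agents.

```python
def step_size_before(action_sizes):
    # Old condition: any non-empty tuple is truthy, so this always yields 0.2.
    return 0.2 if action_sizes else 0.5


def step_size_after(action_sizes):
    # New condition: only the exact tuple (0, 1) yields 0.2.
    return 0.2 if action_sizes == (0, 1) else 0.5


# Same parameter values as the pytest parametrization in the diff.
for sizes in [(0, 1), (1, 0)]:
    print(sizes, step_size_before(sizes), step_size_after(sizes))
```

With the old expression the `else 0.5` branch could only be reached by an empty tuple, which the parametrization never supplies.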
