浏览代码

fixed torch test sac

/develop/actionmodel-csharp
Andrew Cohen 4 年前
当前提交
7af25330
共有 3 个文件被更改,包括 5 次插入和 4 次删除
  1. 4
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  2. 3
      ml-agents/mlagents/trainers/tests/torch/test_sac.py
  3. 2
      ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py

4
ml-agents/mlagents/trainers/sac/optimizer_torch.py


cont_sampled_actions = sampled_actions.continuous_tensor
cont_actions = actions.continuous_tensor
disc_actions = actions.discrete_tensor
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,

sequence_length=self.policy.sequence_length,
)
if self._action_spec.discrete_size:
if self._action_spec.discrete_size > 0:
disc_actions = actions.discrete_tensor
q1_stream = self._condense_q_streams(q1_out, disc_actions)
q2_stream = self._condense_q_streams(q2_out, disc_actions)
else:

3
ml-agents/mlagents/trainers/tests/torch/test_sac.py


"Losses/Value Loss",
"Losses/Q1 Loss",
"Losses/Q2 Loss",
"Policy/Entropy Coeff",
"Policy/Continuous Entropy Coeff",
"Policy/Discrete Entropy Coeff",
"Policy/Learning Rate",
]
for stat in required_stats:

2
ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py


@pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)])
def test_recurrent_sac(action_sizes):
step_size = 0.2 if action_sizes else 0.5
step_size = 0.2 if action_sizes == (0, 1) else 0.5
env = MemoryEnvironment(
[BRAIN_NAME], action_sizes=action_sizes, step_size=step_size
)

正在加载...
取消
保存