浏览代码

Remove some unneeded stuff

/develop/add-fire/policy-tests
Ervin Teng 4 年前
当前提交
020ce8ad
共有 1 个文件被更改,包括 4 次插入4 次删除
  1. 8
      ml-agents/mlagents/trainers/tests/torch/test_policy.py

8
ml-agents/mlagents/trainers/tests/torch/test_policy.py


NetworkSettings.MemorySettings() if use_rnn else None
)
policy = TorchPolicy(seed, mock_spec, trainer_settings)
return policy, trainer_settings
return policy
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])

# Test evaluate
policy, _ = create_policy_mock(
policy = create_policy_mock(
TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
decision_step, terminal_step = mb.create_steps_from_behavior_spec(

@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_evaluate_actions(rnn, visual, discrete):
policy, trainer_settings = create_policy_mock(
policy = create_policy_mock(
TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size)

@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_sample_actions(rnn, visual, discrete):
policy, trainer_settings = create_policy_mock(
policy = create_policy_mock(
TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size)

正在加载...
取消
保存