|
|
|
|
|
|
NetworkSettings.MemorySettings() if use_rnn else None |
|
|
|
) |
|
|
|
policy = TorchPolicy(seed, mock_spec, trainer_settings) |
|
|
|
return policy, trainer_settings |
|
|
|
return policy |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) |
|
|
|
|
|
|
# Test evaluate |
|
|
|
policy, _ = create_policy_mock( |
|
|
|
policy = create_policy_mock( |
|
|
|
TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual |
|
|
|
) |
|
|
|
decision_step, terminal_step = mb.create_steps_from_behavior_spec( |
|
|
|
|
|
|
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"]) |
|
|
|
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"]) |
|
|
|
def test_evaluate_actions(rnn, visual, discrete): |
|
|
|
policy, trainer_settings = create_policy_mock( |
|
|
|
policy = create_policy_mock( |
|
|
|
TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual |
|
|
|
) |
|
|
|
buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size) |
|
|
|
|
|
|
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"]) |
|
|
|
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"]) |
|
|
|
def test_sample_actions(rnn, visual, discrete): |
|
|
|
policy, trainer_settings = create_policy_mock( |
|
|
|
policy = create_policy_mock( |
|
|
|
TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual |
|
|
|
) |
|
|
|
buffer = mb.simulate_rollout(64, policy.behavior_spec, memory_size=policy.m_size) |
|
|
|