|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
@mock.patch("mlagents.envs.environment.UnityEnvironment") |
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment") |
|
|
|
def test_gail_cc(mock_env, trainer_config, gail_dummy_config): |
|
|
|
env, policy = create_policy_mock( |
|
|
|
mock_env, trainer_config, gail_dummy_config, False, False, False |
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
@mock.patch("mlagents.envs.environment.UnityEnvironment") |
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment") |
|
|
|
def test_gail_dc_visual(mock_env, trainer_config, gail_dummy_config): |
|
|
|
gail_dummy_config["gail"]["demo_path"] = ( |
|
|
|
os.path.dirname(os.path.abspath(__file__)) + "/testdcvis.demo" |
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
@mock.patch("mlagents.envs.environment.UnityEnvironment") |
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment") |
|
|
|
def test_gail_rnn(mock_env, trainer_config, gail_dummy_config): |
|
|
|
env, policy = create_policy_mock( |
|
|
|
mock_env, trainer_config, gail_dummy_config, True, False, False |
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
@mock.patch("mlagents.envs.environment.UnityEnvironment") |
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment") |
|
|
|
def test_curiosity_cc(mock_env, trainer_config, curiosity_dummy_config): |
|
|
|
env, policy = create_policy_mock( |
|
|
|
mock_env, trainer_config, curiosity_dummy_config, False, False, False |
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
@mock.patch("mlagents.envs.environment.UnityEnvironment") |
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment") |
|
|
|
def test_curiosity_dc(mock_env, trainer_config, curiosity_dummy_config): |
|
|
|
env, policy = create_policy_mock( |
|
|
|
mock_env, trainer_config, curiosity_dummy_config, False, True, False |
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
@mock.patch("mlagents.envs.environment.UnityEnvironment") |
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment") |
|
|
|
def test_curiosity_visual(mock_env, trainer_config, curiosity_dummy_config): |
|
|
|
env, policy = create_policy_mock( |
|
|
|
mock_env, trainer_config, curiosity_dummy_config, False, False, True |
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
@mock.patch("mlagents.envs.environment.UnityEnvironment") |
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment") |
|
|
|
def test_curiosity_rnn(mock_env, trainer_config, curiosity_dummy_config): |
|
|
|
env, policy = create_policy_mock( |
|
|
|
mock_env, trainer_config, curiosity_dummy_config, True, False, False |
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
|
|
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"] |
|
|
|
) |
|
|
|
@mock.patch("mlagents.envs.environment.UnityEnvironment") |
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment") |
|
|
|
def test_extrinsic(mock_env, trainer_config, curiosity_dummy_config): |
|
|
|
env, policy = create_policy_mock( |
|
|
|
mock_env, trainer_config, curiosity_dummy_config, False, False, False |
|
|
|