|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def create_mock_3dball_brain(): |
|
|
|
mock_brain = mb.create_mock_brainparams( |
|
|
|
vector_action_space_type="continuous", |
|
|
|
vector_action_space_size=[2], |
|
|
|
vector_observation_space_size=8, |
|
|
|
) |
|
|
|
return mock_brain |
|
|
|
|
|
|
|
|
|
|
|
def create_mock_banana_brain(): |
|
|
|
mock_brain = mb.create_mock_brainparams( |
|
|
|
number_visual_observations=1, |
|
|
|
vector_action_space_type="discrete", |
|
|
|
vector_action_space_size=[3, 3, 3, 2], |
|
|
|
vector_observation_space_size=0, |
|
|
|
) |
|
|
|
return mock_brain |
|
|
|
|
|
|
|
|
|
|
|
def create_ppo_policy_with_bc_mock( |
|
|
|
mock_env, mock_brain, dummy_config, use_rnn, demo_file |
|
|
|
): |
|
|
|
|
|
|
@mock.patch("mlagents.envs.UnityEnvironment") |
|
|
|
def test_bcmodule_defaults(mock_env, dummy_config): |
|
|
|
# See if default values match |
|
|
|
mock_brain = create_mock_3dball_brain() |
|
|
|
mock_brain = mb.create_mock_3dball_brain() |
|
|
|
env, policy = create_ppo_policy_with_bc_mock( |
|
|
|
mock_env, mock_brain, dummy_config, False, "test.demo" |
|
|
|
) |
|
|
|
|
|
|
# Test with continuous control env and vector actions |
|
|
|
@mock.patch("mlagents.envs.UnityEnvironment") |
|
|
|
def test_bcmodule_update(mock_env, dummy_config): |
|
|
|
mock_brain = create_mock_3dball_brain() |
|
|
|
mock_brain = mb.create_mock_3dball_brain() |
|
|
|
env, policy = create_ppo_policy_with_bc_mock( |
|
|
|
mock_env, mock_brain, dummy_config, False, "test.demo" |
|
|
|
) |
|
|
|
|
|
|
# Test with RNN |
|
|
|
@mock.patch("mlagents.envs.UnityEnvironment") |
|
|
|
def test_bcmodule_rnn_update(mock_env, dummy_config): |
|
|
|
mock_brain = create_mock_3dball_brain() |
|
|
|
mock_brain = mb.create_mock_3dball_brain() |
|
|
|
env, policy = create_ppo_policy_with_bc_mock( |
|
|
|
mock_env, mock_brain, dummy_config, True, "test.demo" |
|
|
|
) |
|
|
|
|
|
|
# Test with discrete control and visual observations |
|
|
|
@mock.patch("mlagents.envs.UnityEnvironment") |
|
|
|
def test_bcmodule_dc_visual_update(mock_env, dummy_config): |
|
|
|
mock_brain = create_mock_banana_brain() |
|
|
|
mock_brain = mb.create_mock_banana_brain() |
|
|
|
env, policy = create_ppo_policy_with_bc_mock( |
|
|
|
mock_env, mock_brain, dummy_config, False, "testdcvis.demo" |
|
|
|
) |
|
|
|
|
|
|
# Test with discrete control, visual observations and RNN |
|
|
|
@mock.patch("mlagents.envs.UnityEnvironment") |
|
|
|
def test_bcmodule_rnn_dc_update(mock_env, dummy_config): |
|
|
|
mock_brain = create_mock_banana_brain() |
|
|
|
mock_brain = mb.create_mock_banana_brain() |
|
|
|
env, policy = create_ppo_policy_with_bc_mock( |
|
|
|
mock_env, mock_brain, dummy_config, True, "testdcvis.demo" |
|
|
|
) |
|
|
|