|
|
|
|
|
|
assert isinstance(env.action_space, spaces.MultiDiscrete) |
|
|
|
|
|
|
|
|
|
|
|
def test_action_space(): |
|
|
|
mock_env = mock.MagicMock() |
|
|
|
mock_spec = create_mock_group_spec( |
|
|
|
vector_action_space_type="discrete", vector_action_space_size=[5] |
|
|
|
) |
|
|
|
mock_decision_step, mock_terminal_step = create_mock_vector_steps( |
|
|
|
mock_spec, num_agents=1 |
|
|
|
) |
|
|
|
setup_mock_unityenvironment( |
|
|
|
mock_env, mock_spec, mock_decision_step, mock_terminal_step |
|
|
|
) |
|
|
|
|
|
|
|
env = UnityToGymWrapper(mock_env, flatten_branched=True) |
|
|
|
assert isinstance(env.action_space, spaces.Discrete) |
|
|
|
assert env.action_space.n == 5 |
|
|
|
|
|
|
|
env = UnityToGymWrapper(mock_env, flatten_branched=False) |
|
|
|
assert isinstance(env.action_space, spaces.Discrete) |
|
|
|
assert env.action_space.n == 5 |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("use_uint8", [True, False], ids=["float", "uint8"]) |
|
|
|
def test_gym_wrapper_visual(use_uint8): |
|
|
|
mock_env = mock.MagicMock() |
|
|
|