浏览代码

Fixing issue raised in #4393 (#4438)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
bae45836
共有 2 个文件被更改,包括 22 次插入1 次删除
  1. 2
      gym-unity/gym_unity/envs/__init__.py
  2. 21
      gym-unity/gym_unity/tests/test_gym.py

2
gym-unity/gym_unity/envs/__init__.py


# Set action spaces
if self.group_spec.is_action_discrete():
branches = self.group_spec.discrete_action_branches
if self.group_spec.action_shape == 1:
if self.group_spec.action_size == 1:
self._action_space = spaces.Discrete(branches[0])
else:
if flatten_branched:

21
gym-unity/gym_unity/tests/test_gym.py


assert isinstance(env.action_space, spaces.MultiDiscrete)
def test_action_space():
mock_env = mock.MagicMock()
mock_spec = create_mock_group_spec(
vector_action_space_type="discrete", vector_action_space_size=[5]
)
mock_decision_step, mock_terminal_step = create_mock_vector_steps(
mock_spec, num_agents=1
)
setup_mock_unityenvironment(
mock_env, mock_spec, mock_decision_step, mock_terminal_step
)
env = UnityToGymWrapper(mock_env, flatten_branched=True)
assert isinstance(env.action_space, spaces.Discrete)
assert env.action_space.n == 5
env = UnityToGymWrapper(mock_env, flatten_branched=False)
assert isinstance(env.action_space, spaces.Discrete)
assert env.action_space.n == 5
@pytest.mark.parametrize("use_uint8", [True, False], ids=["float", "uint8"])
def test_gym_wrapper_visual(use_uint8):
mock_env = mock.MagicMock()

正在加载...
取消
保存