浏览代码

[MLA-1952] Add optional seed for gym action spaces (#5303)

* add optional seed for action spaces

* add optional seed for action spaces

* changelog

* undo packages-lock.json change
/colab-links
GitHub 4 年前
当前提交
fd8737fd
共有 3 个文件被更改,包括 26 次插入和 8 次删除
  1. 4
      com.unity.ml-agents/CHANGELOG.md
  2. 9
      gym-unity/gym_unity/envs/__init__.py
  3. 21
      gym-unity/gym_unity/tests/test_gym.py

4
com.unity.ml-agents/CHANGELOG.md


## [Unreleased]
### Major Changes
### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
The `UnityToGymWrapper` initializer now accepts an optional `action_space_seed` argument. If this is specified, it will
be used to set the random seed on the resulting action space. (#5303)
### Bug Fixes

9
gym-unity/gym_unity/envs/__init__.py


import itertools
import numpy as np
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import gym
from gym import error, spaces

uint8_visual: bool = False,
flatten_branched: bool = False,
allow_multiple_obs: bool = False,
action_space_seed: Optional[int] = None,
):
"""
Environment initialization

containing the visual observations and the last element containing the array of vector observations.
If False, returns a single np.ndarray containing either only a single visual observation or the array of
vector observations.
:param action_space_seed: If non-None, will be used to set the random seed on created gym.Space instances.
"""
self._env = unity_env

"The gym wrapper does not provide explicit support for both discrete "
"and continuous actions."
)
if action_space_seed is not None:
self._action_space.seed(action_space_seed)
# Set observations space
list_spaces: List[gym.Space] = []

return -float("inf"), float("inf")
@property
def action_space(self) -> gym.Space:
    """Return the gym action space held in ``self._action_space``."""
    return self._action_space
@property

21
gym-unity/gym_unity/tests/test_gym.py


mock_env, mock_spec, mock_decision_step, mock_terminal_step
)
env = UnityToGymWrapper(mock_env)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.reset(), np.ndarray)
actions = env.action_space.sample()
assert actions.shape[0] == 2

assert env.action_space.n == 5
def test_action_space_seed():
    """Two wrappers built with the same ``action_space_seed`` sample identically.

    Each wrapper is created around the same mocked environment, reset once,
    and asked for a single action sample; the two samples must match
    element-wise.
    """
    mock_env = mock.MagicMock()
    group_spec = create_mock_group_spec()
    decision_step, terminal_step = create_mock_vector_steps(group_spec)
    setup_mock_unityenvironment(mock_env, group_spec, decision_step, terminal_step)

    def first_sample_with_seed(seed):
        # Build a fresh wrapper so the action space is re-created and re-seeded.
        wrapper = UnityToGymWrapper(mock_env, action_space_seed=seed)
        wrapper.reset()
        return wrapper.action_space.sample()

    first_sample = first_sample_with_seed(1337)
    second_sample = first_sample_with_seed(1337)
    assert (first_sample == second_sample).all()
@pytest.mark.parametrize("use_uint8", [True, False], ids=["float", "uint8"])
def test_gym_wrapper_visual(use_uint8):
mock_env = mock.MagicMock()

env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8)
assert isinstance(env.observation_space, spaces.Box)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.reset(), np.ndarray)
actions = env.action_space.sample()
assert actions.shape[0] == 2

)
env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=True)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.observation_space, spaces.Tuple)
assert len(env.observation_space) == 2
reset_obs = env.reset()

# check behavior for allow_multiple_obs = False
env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=False)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.observation_space, spaces.Box)
reset_obs = env.reset()
assert isinstance(reset_obs, np.ndarray)

)
env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=True)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.observation_space, spaces.Tuple)
assert len(env.observation_space) == 3
reset_obs = env.reset()

# check behavior for allow_multiple_obs = False
env = UnityToGymWrapper(mock_env, uint8_visual=use_uint8, allow_multiple_obs=False)
assert isinstance(env, UnityToGymWrapper)
assert isinstance(env.observation_space, spaces.Box)
reset_obs = env.reset()
assert isinstance(reset_obs, np.ndarray)

正在加载...
取消
保存