浏览代码

Adding unit tests

/develop/gym-wrapper
vincentpierre 5 年前
当前提交
d034f387
共有 3 个文件被更改,包括 43 次插入1 次删除
  1. 5
      ml-agents-envs/mlagents_envs/gym_wrapper.py
  2. 1
      ml-agents-envs/setup.py
  3. 38
      ml-agents-envs/mlagents_envs/tests/test_gym_wrapper.py

5
ml-agents-envs/mlagents_envs/gym_wrapper.py


assert behavior_name == self._behavior_name
spec = self._behavior_specs
expected_type = np.float32 if spec.is_action_continuous() else np.int32
expected_shape = (1, spec.action_size)
n_agents = len(self._current_steps[0])
expected_shape = (n_agents, spec.action_size)
if action.shape != expected_shape:
raise UnityActionException(
"The behavior {0} needs an input of dimension {1} but received input of dimension {2}".format(

if action.dtype != expected_type:
action = action.astype(expected_type)
if n_agents == 0:
return
if isinstance(self._gym_env.action_space, gym.spaces.Discrete):
self._g_action = int(action[0, 0])
elif isinstance(self._gym_env.action_space, gym.spaces.Box):

1
ml-agents-envs/setup.py


zip_safe=False,
install_requires=[
"gym",
"gym[atari]",
"cloudpickle",
"grpcio>=1.11.0",
"numpy>=1.14.1,<2.0",

38
ml-agents-envs/mlagents_envs/tests/test_gym_wrapper.py


from mlagents_envs.gym_wrapper import GymWrapper
from mlagents_envs.base_env import ActionType
import gym
import pytest
GYM_ENVS = ["CartPole-v1", "MountainCar-v0", "AirRaid-v0"]
@pytest.mark.parametrize("name", GYM_ENVS, ids=GYM_ENVS)
def test_creation(name):
env = GymWrapper(gym.make(name), name)
env.close()
@pytest.mark.parametrize("name", GYM_ENVS, ids=GYM_ENVS)
def test_specs(name):
gym_env = gym.make(name)
env = GymWrapper(gym_env, name)
assert env.get_behavior_names()[0] == name
if isinstance(gym_env.action_space, gym.spaces.Box):
assert env.get_behavior_spec(name).action_type == ActionType.CONTINUOUS
elif isinstance(gym_env.action_space, gym.spaces.Discrete):
assert env.get_behavior_spec(name).action_type == ActionType.DISCRETE
else:
raise NotImplementedError("Test for this action space type not implemented")
env.close()
@pytest.mark.parametrize("name", GYM_ENVS, ids=GYM_ENVS)
def test_steps(name):
env = GymWrapper(gym.make(name), name)
spec = env.get_behavior_spec(name)
env.reset()
for _ in range(200):
d_steps, t_steps = env.get_steps(name)
env.set_actions(name, spec.create_empty_action(len(d_steps)))
env.step()
env.close()
正在加载...
取消
保存