浏览代码

Safer import try catch

/develop/gym-wrapper
vincentpierre 4 年前
当前提交
ef4cd9c9
共有 2 个文件被更改,包括 23 次插入4 次删除
  1. 8
      ml-agents-envs/mlagents_envs/gym_wrapper.py
  2. 19
      ml-agents-envs/mlagents_envs/tests/test_gym_wrapper.py

8
ml-agents-envs/mlagents_envs/gym_wrapper.py


try:
import gym
_GYM_IMPORTED = True
raise ImportError("gym is not installed, gym required to use the GymToUnityWrapper")
_GYM_IMPORTED = False
class GymToUnityWrapper(BaseEnv):

def __init__(self, gym_env, name=None):
if not _GYM_IMPORTED:
raise RuntimeError(
"gym is not installed, gym required to use the GymToUnityWrapper"
)
self._gym_env = gym_env
self._first_message = True
self._behavior_name = name

19
ml-agents-envs/mlagents_envs/tests/test_gym_wrapper.py


try:
import gym
_GYM_IMPORTED = True
raise ImportError(
"gym is not installed, please call `pip install gym` before running this test"
)
_GYM_IMPORTED = False
import pytest
GYM_ENVS = ["CartPole-v1", "MountainCar-v0"]

def test_creation(name):
if not _GYM_IMPORTED:
raise RuntimeError(
"gym is not installed, gym required to test the GymToUnityWrapper"
)
env = GymToUnityWrapper(gym.make(name), name)
env.close()

if not _GYM_IMPORTED:
raise RuntimeError(
"gym is not installed, gym required to test the GymToUnityWrapper"
)
gym_env = gym.make(name)
env = GymToUnityWrapper(gym_env, name)
assert env.get_behavior_names()[0] == name

@pytest.mark.parametrize("name", GYM_ENVS, ids=GYM_ENVS)
def test_steps(name):
if not _GYM_IMPORTED:
raise RuntimeError(
"gym is not installed, gym required to test the GymToUnityWrapper"
)
env = GymToUnityWrapper(gym.make(name), name)
spec = env.get_behavior_spec(name)
env.reset()

正在加载...
取消
保存