浏览代码
Adding some tests (#3952)
Adding some tests (#3952)
* Adding some tests * fixing the step test by enforcing the types of np.array * addressing comments/docs-update
GitHub
5 年前
当前提交
3a8f6d0c
共有 4 个文件被更改,包括 176 次插入 和 9 次删除
-
6ml-agents-envs/mlagents_envs/base_env.py
-
13ml-agents-envs/mlagents_envs/env_utils.py
-
64ml-agents-envs/mlagents_envs/tests/test_env_utils.py
-
102ml-agents-envs/mlagents_envs/tests/test_steps.py
|
|||
from unittest import mock |
|||
import pytest |
|||
from mlagents_envs.env_utils import validate_environment_path, launch_executable |
|||
from mlagents_envs.exception import UnityEnvironmentException |
|||
from mlagents_envs.logging_util import ( |
|||
set_log_level, |
|||
get_logger, |
|||
INFO, |
|||
ERROR, |
|||
FATAL, |
|||
CRITICAL, |
|||
DEBUG, |
|||
) |
|||
|
|||
|
|||
def mock_glob_method(path): |
|||
""" |
|||
Given a path input, returns a list of candidates |
|||
""" |
|||
if ".x86" in path: |
|||
return ["linux"] |
|||
if ".app" in path: |
|||
return ["darwin"] |
|||
if ".exe" in path: |
|||
return ["win32"] |
|||
if "*" in path: |
|||
return "Any" |
|||
return [] |
|||
|
|||
|
|||
@mock.patch("sys.platform") |
|||
@mock.patch("glob.glob") |
|||
def test_validate_path_empty(glob_mock, platform_mock): |
|||
glob_mock.return_value = None |
|||
path = validate_environment_path(" ") |
|||
assert path is None |
|||
|
|||
|
|||
@mock.patch("mlagents_envs.env_utils.get_platform") |
|||
@mock.patch("glob.glob") |
|||
def test_validate_path(glob_mock, platform_mock): |
|||
glob_mock.side_effect = mock_glob_method |
|||
for platform in ["linux", "darwin", "win32"]: |
|||
platform_mock.return_value = platform |
|||
path = validate_environment_path(" ") |
|||
assert path == platform |
|||
|
|||
|
|||
@mock.patch("glob.glob") |
|||
@mock.patch("subprocess.Popen") |
|||
def test_launch_executable(mock_popen, glob_mock): |
|||
with pytest.raises(UnityEnvironmentException): |
|||
launch_executable(" ", []) |
|||
glob_mock.return_value = ["FakeLaunchPath"] |
|||
launch_executable(" ", []) |
|||
mock_popen.side_effect = PermissionError("Fake permission error") |
|||
with pytest.raises(UnityEnvironmentException): |
|||
launch_executable(" ", []) |
|||
|
|||
|
|||
def test_set_logging_level(): |
|||
for level in [INFO, ERROR, FATAL, CRITICAL, DEBUG]: |
|||
set_log_level(level) |
|||
assert get_logger("test").level == level |
|
|||
import pytest |
|||
import numpy as np |
|||
|
|||
from mlagents_envs.base_env import ( |
|||
DecisionSteps, |
|||
TerminalSteps, |
|||
ActionType, |
|||
BehaviorSpec, |
|||
) |
|||
|
|||
|
|||
def test_decision_steps(): |
|||
ds = DecisionSteps( |
|||
obs=[np.array(range(12), dtype=np.float32).reshape(3, 4)], |
|||
reward=np.array(range(3), dtype=np.float32), |
|||
agent_id=np.array(range(10, 13), dtype=np.int32), |
|||
action_mask=[np.zeros((3, 4), dtype=np.bool)], |
|||
) |
|||
|
|||
assert ds.agent_id_to_index[10] == 0 |
|||
assert ds.agent_id_to_index[11] == 1 |
|||
assert ds.agent_id_to_index[12] == 2 |
|||
|
|||
with pytest.raises(KeyError): |
|||
assert ds.agent_id_to_index[-1] == -1 |
|||
|
|||
mask_agent = ds[10].action_mask |
|||
assert isinstance(mask_agent, list) |
|||
assert len(mask_agent) == 1 |
|||
assert np.array_equal(mask_agent[0], np.zeros((4), dtype=np.bool)) |
|||
|
|||
for agent_id in ds: |
|||
assert ds.agent_id_to_index[agent_id] in range(3) |
|||
|
|||
|
|||
def test_empty_decision_steps(): |
|||
specs = BehaviorSpec( |
|||
observation_shapes=[(3, 2), (5,)], |
|||
action_type=ActionType.CONTINUOUS, |
|||
action_shape=3, |
|||
) |
|||
ds = DecisionSteps.empty(specs) |
|||
assert len(ds.obs) == 2 |
|||
assert ds.obs[0].shape == (0, 3, 2) |
|||
assert ds.obs[1].shape == (0, 5) |
|||
|
|||
|
|||
def test_terminal_steps(): |
|||
ts = TerminalSteps( |
|||
obs=[np.array(range(12), dtype=np.float32).reshape(3, 4)], |
|||
reward=np.array(range(3), dtype=np.float32), |
|||
agent_id=np.array(range(10, 13), dtype=np.int32), |
|||
interrupted=np.array([1, 0, 1], dtype=np.bool), |
|||
) |
|||
|
|||
assert ts.agent_id_to_index[10] == 0 |
|||
assert ts.agent_id_to_index[11] == 1 |
|||
assert ts.agent_id_to_index[12] == 2 |
|||
|
|||
assert ts[10].interrupted |
|||
assert not ts[11].interrupted |
|||
assert ts[12].interrupted |
|||
|
|||
with pytest.raises(KeyError): |
|||
assert ts.agent_id_to_index[-1] == -1 |
|||
|
|||
for agent_id in ts: |
|||
assert ts.agent_id_to_index[agent_id] in range(3) |
|||
|
|||
|
|||
def test_empty_terminal_steps(): |
|||
specs = BehaviorSpec( |
|||
observation_shapes=[(3, 2), (5,)], |
|||
action_type=ActionType.CONTINUOUS, |
|||
action_shape=3, |
|||
) |
|||
ts = TerminalSteps.empty(specs) |
|||
assert len(ts.obs) == 2 |
|||
assert ts.obs[0].shape == (0, 3, 2) |
|||
assert ts.obs[1].shape == (0, 5) |
|||
|
|||
|
|||
def test_specs(): |
|||
specs = BehaviorSpec( |
|||
observation_shapes=[(3, 2), (5,)], |
|||
action_type=ActionType.CONTINUOUS, |
|||
action_shape=3, |
|||
) |
|||
assert specs.discrete_action_branches is None |
|||
assert specs.action_size == 3 |
|||
assert specs.create_empty_action(5).shape == (5, 3) |
|||
assert specs.create_empty_action(5).dtype == np.float32 |
|||
|
|||
specs = BehaviorSpec( |
|||
observation_shapes=[(3, 2), (5,)], |
|||
action_type=ActionType.DISCRETE, |
|||
action_shape=(3,), |
|||
) |
|||
assert specs.discrete_action_branches == (3,) |
|||
assert specs.action_size == 1 |
|||
assert specs.create_empty_action(5).shape == (5, 1) |
|||
assert specs.create_empty_action(5).dtype == np.int32 |
撰写
预览
正在加载...
取消
保存
Reference in new issue