您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
125 行
4.7 KiB
125 行
4.7 KiB
from unittest import mock
|
|
import pytest
|
|
|
|
import numpy as np
|
|
|
|
from mlagents_envs.environment import UnityEnvironment
|
|
from mlagents_envs.base_env import BatchedStepResult
|
|
from mlagents_envs.exception import UnityEnvironmentException, UnityActionException
|
|
from mlagents_envs.mock_communicator import MockCommunicator
|
|
|
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_handles_bad_filename(get_communicator):
|
|
with pytest.raises(UnityEnvironmentException):
|
|
UnityEnvironment(" ")
|
|
|
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_initialization(mock_communicator, mock_launcher):
|
|
mock_communicator.return_value = MockCommunicator(
|
|
discrete_action=False, visual_inputs=0
|
|
)
|
|
env = UnityEnvironment(" ")
|
|
assert env.get_agent_groups() == ["RealFakeBrain"]
|
|
env.close()
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"base_port,file_name,expected",
|
|
[
|
|
# Non-None base port value will always be used
|
|
(6001, "foo.exe", 6001),
|
|
# No port specified and environment specified, so use BASE_ENVIRONMENT_PORT
|
|
(None, "foo.exe", UnityEnvironment.BASE_ENVIRONMENT_PORT),
|
|
# No port specified and no environment, so use DEFAULT_EDITOR_PORT
|
|
(None, None, UnityEnvironment.DEFAULT_EDITOR_PORT),
|
|
],
|
|
)
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_port_defaults(
|
|
mock_communicator, mock_launcher, base_port, file_name, expected
|
|
):
|
|
mock_communicator.return_value = MockCommunicator(
|
|
discrete_action=False, visual_inputs=0
|
|
)
|
|
env = UnityEnvironment(file_name=file_name, worker_id=0, base_port=base_port)
|
|
assert expected == env.port
|
|
|
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_reset(mock_communicator, mock_launcher):
|
|
mock_communicator.return_value = MockCommunicator(
|
|
discrete_action=False, visual_inputs=0
|
|
)
|
|
env = UnityEnvironment(" ")
|
|
spec = env.get_agent_group_spec("RealFakeBrain")
|
|
env.reset()
|
|
batched_step_result = env.get_step_result("RealFakeBrain")
|
|
env.close()
|
|
assert isinstance(batched_step_result, BatchedStepResult)
|
|
assert len(spec.observation_shapes) == len(batched_step_result.obs)
|
|
n_agents = batched_step_result.n_agents()
|
|
for shape, obs in zip(spec.observation_shapes, batched_step_result.obs):
|
|
assert (n_agents,) + shape == obs.shape
|
|
|
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_step(mock_communicator, mock_launcher):
|
|
mock_communicator.return_value = MockCommunicator(
|
|
discrete_action=False, visual_inputs=0
|
|
)
|
|
env = UnityEnvironment(" ")
|
|
spec = env.get_agent_group_spec("RealFakeBrain")
|
|
env.step()
|
|
batched_step_result = env.get_step_result("RealFakeBrain")
|
|
n_agents = batched_step_result.n_agents()
|
|
env.set_actions(
|
|
"RealFakeBrain", np.zeros((n_agents, spec.action_size), dtype=np.float32)
|
|
)
|
|
env.step()
|
|
with pytest.raises(UnityActionException):
|
|
env.set_actions(
|
|
"RealFakeBrain",
|
|
np.zeros((n_agents - 1, spec.action_size), dtype=np.float32),
|
|
)
|
|
batched_step_result = env.get_step_result("RealFakeBrain")
|
|
n_agents = batched_step_result.n_agents()
|
|
env.set_actions(
|
|
"RealFakeBrain", -1 * np.ones((n_agents, spec.action_size), dtype=np.float32)
|
|
)
|
|
env.step()
|
|
|
|
env.close()
|
|
assert isinstance(batched_step_result, BatchedStepResult)
|
|
assert len(spec.observation_shapes) == len(batched_step_result.obs)
|
|
for shape, obs in zip(spec.observation_shapes, batched_step_result.obs):
|
|
assert (n_agents,) + shape == obs.shape
|
|
assert not batched_step_result.done[0]
|
|
assert batched_step_result.done[2]
|
|
|
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_close(mock_communicator, mock_launcher):
|
|
comm = MockCommunicator(discrete_action=False, visual_inputs=0)
|
|
mock_communicator.return_value = comm
|
|
env = UnityEnvironment(" ")
|
|
assert env._loaded
|
|
env.close()
|
|
assert not env._loaded
|
|
assert comm.has_been_closed
|
|
|
|
|
|
def test_returncode_to_signal_name():
|
|
assert UnityEnvironment.returncode_to_signal_name(-2) == "SIGINT"
|
|
assert UnityEnvironment.returncode_to_signal_name(42) is None
|
|
assert UnityEnvironment.returncode_to_signal_name("SIGINT") is None
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main()
|