您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
175 行
6.4 KiB
175 行
6.4 KiB
from unittest import mock
|
|
import pytest
|
|
|
|
import numpy as np
|
|
|
|
from mlagents_envs.environment import UnityEnvironment
|
|
from mlagents_envs.base_env import DecisionSteps, TerminalSteps
|
|
from mlagents_envs.exception import UnityEnvironmentException, UnityActionException
|
|
from mlagents_envs.mock_communicator import MockCommunicator
|
|
|
|
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_handles_bad_filename(get_communicator):
|
|
with pytest.raises(UnityEnvironmentException):
|
|
UnityEnvironment(" ")
|
|
|
|
|
|
@mock.patch("mlagents_envs.env_utils.launch_executable")
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_initialization(mock_communicator, mock_launcher):
|
|
mock_communicator.return_value = MockCommunicator(
|
|
discrete_action=False, visual_inputs=0
|
|
)
|
|
env = UnityEnvironment(" ")
|
|
assert list(env.behavior_specs.keys()) == ["RealFakeBrain"]
|
|
env.close()
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"base_port,file_name,expected",
|
|
[
|
|
# Non-None base port value will always be used
|
|
(6001, "foo.exe", 6001),
|
|
# No port specified and environment specified, so use BASE_ENVIRONMENT_PORT
|
|
(None, "foo.exe", UnityEnvironment.BASE_ENVIRONMENT_PORT),
|
|
# No port specified and no environment, so use DEFAULT_EDITOR_PORT
|
|
(None, None, UnityEnvironment.DEFAULT_EDITOR_PORT),
|
|
],
|
|
)
|
|
@mock.patch("mlagents_envs.env_utils.launch_executable")
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_port_defaults(
|
|
mock_communicator, mock_launcher, base_port, file_name, expected
|
|
):
|
|
mock_communicator.return_value = MockCommunicator(
|
|
discrete_action=False, visual_inputs=0
|
|
)
|
|
env = UnityEnvironment(file_name=file_name, worker_id=0, base_port=base_port)
|
|
assert expected == env.port
|
|
|
|
|
|
@mock.patch("mlagents_envs.env_utils.launch_executable")
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_log_file_path_is_set(mock_communicator, mock_launcher):
|
|
mock_communicator.return_value = MockCommunicator()
|
|
env = UnityEnvironment(
|
|
file_name="myfile", worker_id=0, log_folder="./some-log-folder-path"
|
|
)
|
|
args = env.executable_args()
|
|
log_file_index = args.index("-logFile")
|
|
assert args[log_file_index + 1] == "./some-log-folder-path/Player-0.log"
|
|
|
|
|
|
@mock.patch("mlagents_envs.env_utils.launch_executable")
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_reset(mock_communicator, mock_launcher):
|
|
mock_communicator.return_value = MockCommunicator(
|
|
discrete_action=False, visual_inputs=0
|
|
)
|
|
env = UnityEnvironment(" ")
|
|
spec = env.behavior_specs["RealFakeBrain"]
|
|
env.reset()
|
|
decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
|
|
env.close()
|
|
assert isinstance(decision_steps, DecisionSteps)
|
|
assert isinstance(terminal_steps, TerminalSteps)
|
|
assert len(spec.observation_shapes) == len(decision_steps.obs)
|
|
assert len(spec.observation_shapes) == len(terminal_steps.obs)
|
|
n_agents = len(decision_steps)
|
|
for shape, obs in zip(spec.observation_shapes, decision_steps.obs):
|
|
assert (n_agents,) + shape == obs.shape
|
|
n_agents = len(terminal_steps)
|
|
for shape, obs in zip(spec.observation_shapes, terminal_steps.obs):
|
|
assert (n_agents,) + shape == obs.shape
|
|
|
|
|
|
@mock.patch("mlagents_envs.env_utils.launch_executable")
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_step(mock_communicator, mock_launcher):
|
|
mock_communicator.return_value = MockCommunicator(
|
|
discrete_action=False, visual_inputs=0
|
|
)
|
|
env = UnityEnvironment(" ")
|
|
spec = env.behavior_specs["RealFakeBrain"]
|
|
env.step()
|
|
decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
|
|
n_agents = len(decision_steps)
|
|
env.set_actions(
|
|
"RealFakeBrain", np.zeros((n_agents, spec.action_size), dtype=np.float32)
|
|
)
|
|
env.step()
|
|
with pytest.raises(UnityActionException):
|
|
env.set_actions(
|
|
"RealFakeBrain",
|
|
np.zeros((n_agents - 1, spec.action_size), dtype=np.float32),
|
|
)
|
|
decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
|
|
n_agents = len(decision_steps)
|
|
env.set_actions(
|
|
"RealFakeBrain", -1 * np.ones((n_agents, spec.action_size), dtype=np.float32)
|
|
)
|
|
env.step()
|
|
|
|
env.close()
|
|
assert isinstance(decision_steps, DecisionSteps)
|
|
assert isinstance(terminal_steps, TerminalSteps)
|
|
assert len(spec.observation_shapes) == len(decision_steps.obs)
|
|
assert len(spec.observation_shapes) == len(terminal_steps.obs)
|
|
for shape, obs in zip(spec.observation_shapes, decision_steps.obs):
|
|
assert (n_agents,) + shape == obs.shape
|
|
assert 0 in decision_steps
|
|
assert 2 in terminal_steps
|
|
|
|
|
|
@mock.patch("mlagents_envs.env_utils.launch_executable")
|
|
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
|
|
def test_close(mock_communicator, mock_launcher):
|
|
comm = MockCommunicator(discrete_action=False, visual_inputs=0)
|
|
mock_communicator.return_value = comm
|
|
env = UnityEnvironment(" ")
|
|
assert env._loaded
|
|
env.close()
|
|
assert not env._loaded
|
|
assert comm.has_been_closed
|
|
|
|
|
|
def test_check_communication_compatibility():
|
|
unity_ver = "1.0.0"
|
|
python_ver = "1.0.0"
|
|
unity_package_version = "0.15.0"
|
|
assert UnityEnvironment.check_communication_compatibility(
|
|
unity_ver, python_ver, unity_package_version
|
|
)
|
|
unity_ver = "1.1.0"
|
|
assert UnityEnvironment.check_communication_compatibility(
|
|
unity_ver, python_ver, unity_package_version
|
|
)
|
|
unity_ver = "2.0.0"
|
|
assert not UnityEnvironment.check_communication_compatibility(
|
|
unity_ver, python_ver, unity_package_version
|
|
)
|
|
|
|
unity_ver = "0.16.0"
|
|
python_ver = "0.16.0"
|
|
assert UnityEnvironment.check_communication_compatibility(
|
|
unity_ver, python_ver, unity_package_version
|
|
)
|
|
unity_ver = "0.17.0"
|
|
assert not UnityEnvironment.check_communication_compatibility(
|
|
unity_ver, python_ver, unity_package_version
|
|
)
|
|
unity_ver = "1.16.0"
|
|
assert not UnityEnvironment.check_communication_compatibility(
|
|
unity_ver, python_ver, unity_package_version
|
|
)
|
|
|
|
|
|
def test_returncode_to_signal_name():
|
|
assert UnityEnvironment.returncode_to_signal_name(-2) == "SIGINT"
|
|
assert UnityEnvironment.returncode_to_signal_name(42) is None
|
|
assert UnityEnvironment.returncode_to_signal_name("SIGINT") is None
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main()
|