浏览代码

Making some things private in UnityEnvironment (#3951)

* Making some things private in UnityEnvironment

* Readding the default ports as public

* removing _SCALAR_ACTION_TYPES and _SINGLE_BRAIN_ACTION_TYPES

* Removing unused method
/docs-update
GitHub 4 年前
当前提交
9083752d
共有 3 个文件被更改,包括 68 次插入93 次删除
  1. 119
      ml-agents-envs/mlagents_envs/environment.py
  2. 36
      ml-agents-envs/mlagents_envs/tests/test_envs.py
  3. 6
      ml-agents/mlagents/trainers/learn.py

119
ml-agents-envs/mlagents_envs/environment.py


import numpy as np
import os
import subprocess
from typing import Dict, List, Optional, Any, Tuple, Mapping as MappingType
from typing import Dict, List, Optional, Tuple, Mapping as MappingType
import mlagents_envs

class UnityEnvironment(BaseEnv):
SCALAR_ACTION_TYPES = (int, np.int32, np.int64, float, np.float32, np.float64)
SINGLE_BRAIN_ACTION_TYPES = SCALAR_ACTION_TYPES + (list, np.ndarray)
# Communication protocol version.
# When connecting to C#, this must be compatible with Academy.k_ApiVersion.
# We follow semantic versioning on the communication version, so existing

BASE_ENVIRONMENT_PORT = 5005
# Command line argument used to pass the port to the executable environment.
PORT_COMMAND_LINE_ARG = "--mlagents-port"
_PORT_COMMAND_LINE_ARG = "--mlagents-port"
@staticmethod
def _raise_version_exception(unity_com_ver: str) -> None:

)
@staticmethod
def check_communication_compatibility(
def _check_communication_compatibility(
unity_com_ver: str, python_api_version: str, unity_package_version: str
) -> bool:
unity_communicator_version = StrictVersion(unity_com_ver)

return True
@staticmethod
def get_capabilities_proto() -> UnityRLCapabilitiesProto:
def _get_capabilities_proto() -> UnityRLCapabilitiesProto:
def warn_csharp_base_capabilities(
def _warn_csharp_base_capabilities(
caps: UnityRLCapabilitiesProto, unity_package_ver: str, python_package_ver: str
) -> None:
if not caps.baseRLCapabilities:

:str log_folder: Optional folder to write the Unity Player log file into. Requires absolute path.
"""
atexit.register(self._close)
self.additional_args = additional_args or []
self.no_graphics = no_graphics
self._additional_args = additional_args or []
self._no_graphics = no_graphics
# If base port is not specified, use BASE_ENVIRONMENT_PORT if we have
# an environment, otherwise DEFAULT_EDITOR_PORT
if base_port is None:

self.port = base_port + worker_id
self._port = base_port + worker_id
self.proc1 = None
self.timeout_wait: int = timeout_wait
self.communicator = self.get_communicator(worker_id, base_port, timeout_wait)
self.worker_id = worker_id
self.side_channel_manager = SideChannelManager(side_channels)
self.log_folder = log_folder
self._proc1 = None
self._timeout_wait: int = timeout_wait
self._communicator = self._get_communicator(worker_id, base_port, timeout_wait)
self._worker_id = worker_id
self._side_channel_manager = SideChannelManager(side_channels)
self._log_folder = log_folder
# If the environment name is None, a new environment will not be launched
# and the communicator will directly try to connect to an existing unity environment.

)
if file_name is not None:
try:
self.proc1 = env_utils.launch_executable(
file_name, self.executable_args()
self._proc1 = env_utils.launch_executable(
file_name, self._executable_args()
)
except UnityEnvironmentException:
self._close(0)

f"Listening on port {self.port}. "
f"Listening on port {self._port}. "
f"Start training by pressing the Play button in the Unity Editor."
)
self._loaded = True

communication_version=self.API_VERSION,
package_version=mlagents_envs.__version__,
capabilities=UnityEnvironment.get_capabilities_proto(),
capabilities=UnityEnvironment._get_capabilities_proto(),
aca_output = self.send_academy_parameters(rl_init_parameters_in)
aca_output = self._send_academy_parameters(rl_init_parameters_in)
if not UnityEnvironment.check_communication_compatibility(
if not UnityEnvironment._check_communication_compatibility(
aca_params.communication_version,
UnityEnvironment.API_VERSION,
aca_params.package_version,

UnityEnvironment.warn_csharp_base_capabilities(
UnityEnvironment._warn_csharp_base_capabilities(
aca_params.capabilities,
aca_params.package_version,
UnityEnvironment.API_VERSION,

self._update_behavior_specs(aca_output)
@staticmethod
def get_communicator(worker_id, base_port, timeout_wait):
def _get_communicator(worker_id, base_port, timeout_wait):
def executable_args(self) -> List[str]:
def _executable_args(self) -> List[str]:
if self.no_graphics:
if self._no_graphics:
args += [UnityEnvironment.PORT_COMMAND_LINE_ARG, str(self.port)]
if self.log_folder:
args += [UnityEnvironment._PORT_COMMAND_LINE_ARG, str(self._port)]
if self._log_folder:
self.log_folder, f"Player-{self.worker_id}.log"
self._log_folder, f"Player-{self._worker_id}.log"
args += self.additional_args
args += self._additional_args
return args
def _update_behavior_specs(self, output: UnityOutputProto) -> None:

DecisionSteps.empty(self._env_specs[brain_name]),
TerminalSteps.empty(self._env_specs[brain_name]),
)
self.side_channel_manager.process_side_channel_message(output.side_channel)
self._side_channel_manager.process_side_channel_message(output.side_channel)
outputs = self.communicator.exchange(self._generate_reset_input())
outputs = self._communicator.exchange(self._generate_reset_input())
if outputs is None:
raise UnityCommunicatorStoppedException("Communicator has exited.")
self._update_behavior_specs(outputs)

].create_empty_action(n_agents)
step_input = self._generate_step_input(self._env_actions)
with hierarchical_timer("communicator.exchange"):
outputs = self.communicator.exchange(step_input)
outputs = self._communicator.exchange(step_input)
if outputs is None:
raise UnityCommunicatorStoppedException("Communicator has exited.")
self._update_behavior_specs(outputs)

force-killing it. Defaults to `self.timeout_wait`.
"""
if timeout is None:
timeout = self.timeout_wait
timeout = self._timeout_wait
self.communicator.close()
if self.proc1 is not None:
self._communicator.close()
if self._proc1 is not None:
self.proc1.wait(timeout=timeout)
signal_name = self.returncode_to_signal_name(self.proc1.returncode)
self._proc1.wait(timeout=timeout)
signal_name = self._returncode_to_signal_name(self._proc1.returncode)
return_info = f"Environment shut down with return code {self.proc1.returncode}{signal_name}."
return_info = f"Environment shut down with return code {self._proc1.returncode}{signal_name}."
self.proc1.kill()
self._proc1.kill()
self.proc1 = None
@classmethod
def _flatten(cls, arr: Any) -> List[float]:
"""
Converts arrays to list.
:param arr: numpy vector.
:return: flattened list.
"""
if isinstance(arr, cls.SCALAR_ACTION_TYPES):
arr = [float(arr)]
if isinstance(arr, np.ndarray):
arr = arr.tolist()
if len(arr) == 0:
return arr
if isinstance(arr[0], np.ndarray):
# pylint: disable=no-member
arr = [item for sublist in arr for item in sublist.tolist()]
if isinstance(arr[0], list):
# pylint: disable=not-an-iterable
arr = [item for sublist in arr for item in sublist]
arr = [float(x) for x in arr]
return arr
self._proc1 = None
@timed
def _generate_step_input(

rl_in.agent_actions[b].value.extend([action])
rl_in.command = STEP
rl_in.side_channel = bytes(
self.side_channel_manager.generate_side_channel_messages()
self._side_channel_manager.generate_side_channel_messages()
return self.wrap_unity_input(rl_in)
return self._wrap_unity_input(rl_in)
self.side_channel_manager.generate_side_channel_messages()
self._side_channel_manager.generate_side_channel_messages()
return self.wrap_unity_input(rl_in)
return self._wrap_unity_input(rl_in)
def send_academy_parameters(
def _send_academy_parameters(
return self.communicator.initialize(inputs)
return self._communicator.initialize(inputs)
def wrap_unity_input(rl_input: UnityRLInputProto) -> UnityInputProto:
def _wrap_unity_input(rl_input: UnityRLInputProto) -> UnityInputProto:
def returncode_to_signal_name(returncode: int) -> Optional[str]:
def _returncode_to_signal_name(returncode: int) -> Optional[str]:
"""
Try to convert return codes into their corresponding signal name.
E.g. returncode_to_signal_name(-2) -> "SIGINT"

36
ml-agents-envs/mlagents_envs/tests/test_envs.py


from mlagents_envs.mock_communicator import MockCommunicator
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_handles_bad_filename(get_communicator):
with pytest.raises(UnityEnvironmentException):
UnityEnvironment(" ")

@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_initialization(mock_communicator, mock_launcher):
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0

],
)
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_port_defaults(
mock_communicator, mock_launcher, base_port, file_name, expected
):

env = UnityEnvironment(file_name=file_name, worker_id=0, base_port=base_port)
assert expected == env.port
assert expected == env._port
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
args = env.executable_args()
args = env._executable_args()
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_reset(mock_communicator, mock_launcher):
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0

@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_step(mock_communicator, mock_launcher):
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0

@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_close(mock_communicator, mock_launcher):
comm = MockCommunicator(discrete_action=False, visual_inputs=0)
mock_communicator.return_value = comm

unity_ver = "1.0.0"
python_ver = "1.0.0"
unity_package_version = "0.15.0"
assert UnityEnvironment.check_communication_compatibility(
assert UnityEnvironment._check_communication_compatibility(
assert UnityEnvironment.check_communication_compatibility(
assert UnityEnvironment._check_communication_compatibility(
assert not UnityEnvironment.check_communication_compatibility(
assert not UnityEnvironment._check_communication_compatibility(
assert UnityEnvironment.check_communication_compatibility(
assert UnityEnvironment._check_communication_compatibility(
assert not UnityEnvironment.check_communication_compatibility(
assert not UnityEnvironment._check_communication_compatibility(
assert not UnityEnvironment.check_communication_compatibility(
assert not UnityEnvironment._check_communication_compatibility(
assert UnityEnvironment.returncode_to_signal_name(-2) == "SIGINT"
assert UnityEnvironment.returncode_to_signal_name(42) is None
assert UnityEnvironment.returncode_to_signal_name("SIGINT") is None
assert UnityEnvironment._returncode_to_signal_name(-2) == "SIGINT"
assert UnityEnvironment._returncode_to_signal_name(42) is None
assert UnityEnvironment._returncode_to_signal_name("SIGINT") is None
if __name__ == "__main__":

6
ml-agents/mlagents/trainers/learn.py


os.path.join(base_path, options.run_id) if options.initialize_from else None
)
run_logs_dir = os.path.join(write_path, "run_logs")
port = options.base_port
port: Optional[int] = options.base_port
# Check if directory exists
handle_existing_directories(
write_path, options.resume, options.force, maybe_init_path

StatsReporter.add_writer(console_writer)
if options.env_path is None:
port = UnityEnvironment.DEFAULT_EDITOR_PORT
port = None
env_factory = create_environment_factory(
options.env_path,
options.no_graphics,

env_path: Optional[str],
no_graphics: bool,
seed: int,
start_port: int,
start_port: Optional[int],
env_args: Optional[List[str]],
log_folder: str,
) -> Callable[[int, List[SideChannel]], BaseEnv]:

正在加载...
取消
保存