浏览代码

Some improvements to the UnityEnvironment class (#3939)

* Fix typo

* Added a side channel utility module to reduce the complexity of UnityEnvironment

* Added a get_side_channel_dict utility method

* Better executable launcher (unarguably)

* Fixing the broken test

* Addressing comments

* [skip ci] Update ml-agents-envs/mlagents_envs/side_channel/side_channel_manager.py

Co-authored-by: Jonathan Harper <jharper+moar@unity3d.com>

* No catch all

Co-authored-by: Jonathan Harper <jharper+moar@unity3d.com>
/docs-update
GitHub 4 年前
当前提交
812983c0
共有 7 个文件被更改,包括 234 次插入194 次删除
  1. 174
      ml-agents-envs/mlagents_envs/environment.py
  2. 12
      ml-agents-envs/mlagents_envs/tests/test_envs.py
  3. 47
      ml-agents-envs/mlagents_envs/tests/test_side_channel.py
  4. 8
      ml-agents/mlagents/trainers/learn.py
  5. 5
      ml-agents/mlagents/trainers/tests/test_learn.py
  6. 101
      ml-agents-envs/mlagents_envs/env_utils.py
  7. 81
      ml-agents-envs/mlagents_envs/side_channel/side_channel_manager.py

174
ml-agents-envs/mlagents_envs/environment.py


import atexit
from distutils.version import StrictVersion
import glob
import uuid
import numpy as np
import os
import subprocess

from mlagents_envs.logging_util import get_logger
from mlagents_envs.side_channel.side_channel import SideChannel, IncomingMessage
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.side_channel_manager import SideChannelManager
from mlagents_envs import env_utils
from mlagents_envs.base_env import (
BaseEnv,

from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
from .rpc_communicator import RpcCommunicator
from sys import platform
import struct
logger = get_logger(__name__)

return capabilities
@staticmethod
def warn_csharp_base_capabitlities(
def warn_csharp_base_capabilities(
caps: UnityRLCapabilitiesProto, unity_package_ver: str, python_package_ver: str
) -> None:
if not caps.baseRLCapabilities:

self.timeout_wait: int = timeout_wait
self.communicator = self.get_communicator(worker_id, base_port, timeout_wait)
self.worker_id = worker_id
self.side_channels: Dict[uuid.UUID, SideChannel] = {}
if side_channels is not None:
for _sc in side_channels:
if _sc.channel_id in self.side_channels:
raise UnityEnvironmentException(
"There cannot be two side channels with the same channel id {0}.".format(
_sc.channel_id
)
)
self.side_channels[_sc.channel_id] = _sc
self.side_channel_manager = SideChannelManager(side_channels)
self.log_folder = log_folder
# If the environment name is None, a new environment will not be launched

"the worker-id must be 0 in order to connect with the Editor."
)
if file_name is not None:
self.executable_launcher(file_name, no_graphics, additional_args)
try:
self.proc1 = env_utils.launch_executable(
file_name, self.executable_args()
)
except UnityEnvironmentException:
self._close(0)
raise
else:
logger.info(
f"Listening on port {self.port}. "

self._close(0)
UnityEnvironment._raise_version_exception(aca_params.communication_version)
UnityEnvironment.warn_csharp_base_capabitlities(
UnityEnvironment.warn_csharp_base_capabilities(
aca_params.capabilities,
aca_params.package_version,
UnityEnvironment.API_VERSION,

def get_communicator(worker_id, base_port, timeout_wait):
    """
    Build the communicator used to exchange messages with Unity.
    :param worker_id: Forwarded unchanged to RpcCommunicator.
    :param base_port: Forwarded unchanged to RpcCommunicator.
    :param timeout_wait: Forwarded unchanged to RpcCommunicator.
    :return: A new RpcCommunicator instance.
    """
    communicator = RpcCommunicator(worker_id, base_port, timeout_wait)
    return communicator
@staticmethod
def validate_environment_path(env_path: str) -> Optional[str]:
    """
    Resolve an environment path to the executable file for the current platform.
    :param env_path: Path to the environment, with or without an executable
    extension.
    :return: The first matching executable path, or None when nothing matches.
    """
    # Strip out executable extensions if passed
    env_path = env_path.strip()
    for extension in (".app", ".exe", ".x86_64", ".x86"):
        env_path = env_path.replace(extension, "")
    true_filename = os.path.basename(os.path.normpath(env_path))
    logger.debug("The true file name is {}".format(true_filename))

    # Bail out early when nothing on disk resembles the requested path.
    if not glob.glob(env_path) and not glob.glob(env_path + ".*"):
        return None

    cwd = os.getcwd()
    # Glob patterns are listed in priority order; the first one that matches wins.
    if platform in ("linux", "linux2"):
        patterns = [
            os.path.join(cwd, env_path) + ".x86_64",
            os.path.join(cwd, env_path) + ".x86",
            env_path + ".x86_64",
            env_path + ".x86",
        ]
    elif platform == "darwin":
        patterns = [
            os.path.join(cwd, env_path + ".app", "Contents", "MacOS", true_filename),
            os.path.join(env_path + ".app", "Contents", "MacOS", true_filename),
            os.path.join(cwd, env_path + ".app", "Contents", "MacOS", "*"),
            os.path.join(env_path + ".app", "Contents", "MacOS", "*"),
        ]
    elif platform == "win32":
        patterns = [os.path.join(cwd, env_path + ".exe"), env_path + ".exe"]
    else:
        # Any other platform has no known executable layout.
        patterns = []

    for pattern in patterns:
        matches = glob.glob(pattern)
        if matches:
            return matches[0]
    return None
def executable_args(self) -> List[str]:
args: List[str] = []
if self.no_graphics:

args += self.additional_args
return args
def executable_launcher(self, file_name, no_graphics, args):
    """
    Resolve the environment executable and start it as a child process.
    The process handle is stored in self.proc1. The command line arguments come
    from self.executable_args(); the no_graphics and args parameters are not
    read by this method.
    :param file_name: Path of the environment to launch.
    :param no_graphics: Unused here.
    :param args: Unused here.
    """
    launch_string = self.validate_environment_path(file_name)
    if launch_string is None:
        # Nothing to launch: close the environment before reporting the failure.
        self._close(0)
        raise UnityEnvironmentException(
            f"Couldn't launch the {file_name} environment. Provided filename does not match any environments."
        )

    logger.debug("This is the launch string {}".format(launch_string))
    # Launch Unity environment
    command = [launch_string] + self.executable_args()
    try:
        # start_new_session=True means that signals to the parent python process
        # (e.g. SIGINT from keyboard interrupt) will not be sent to the new process on POSIX platforms.
        # This is generally good since we want the environment to have a chance to shutdown,
        # but may be undesirable in some cases; if so, we'll add a command-line toggle.
        # Note that on Windows, the CTRL_C signal will still be sent.
        self.proc1 = subprocess.Popen(command, start_new_session=True)
    except PermissionError as perm:
        # This is likely due to missing read or execute permissions on file.
        raise UnityEnvironmentException(
            f"Error when trying to launch environment - make sure "
            f"permissions are set correctly. For example "
            f'"chmod -R 755 {launch_string}"'
        ) from perm
def _update_behavior_specs(self, output: UnityOutputProto) -> None:
init_output = output.rl_initialization_output
for brain_param in init_output.brain_parameters:

DecisionSteps.empty(self._env_specs[brain_name]),
TerminalSteps.empty(self._env_specs[brain_name]),
)
self._parse_side_channel_message(self.side_channels, output.side_channel)
self.side_channel_manager.process_side_channel_message(output.side_channel)
def reset(self) -> None:
if self._loaded:

arr = [float(x) for x in arr]
return arr
@staticmethod
def _parse_side_channel_message(
    side_channels: Dict[uuid.UUID, SideChannel], data: bytes
) -> None:
    """
    Split a packed side channel payload into messages and dispatch each one.
    Every message is laid out as a 16 byte little-endian channel UUID, a 4 byte
    little-endian signed length, and then that many bytes of message data.
    :param side_channels: Mapping of channel id to the SideChannel that handles it.
    :param data: The packed payload to decode.
    :raises UnityEnvironmentException: If a header cannot be decoded or a
    message body is shorter than its declared length.
    """
    offset = 0
    while offset < len(data):
        try:
            channel_id = uuid.UUID(bytes_le=bytes(data[offset : offset + 16]))
            offset += 16
            (message_len,) = struct.unpack_from("<i", data, offset)
            offset = offset + 4
            message_data = data[offset : offset + message_len]
            offset = offset + message_len
        except (struct.error, ValueError, IndexError) as err:
            # Only decoding failures are translated; a catch-all here would
            # disguise unrelated errors as a version mismatch.
            raise UnityEnvironmentException(
                "There was a problem reading a message in a SideChannel. "
                "Please make sure the version of MLAgents in Unity is "
                "compatible with the Python version."
            ) from err
        # Slicing never raises, so a truncated body shows up as a length mismatch.
        if len(message_data) != message_len:
            raise UnityEnvironmentException(
                "The message received by the side channel {0} was "
                "unexpectedly short. Make sure your Unity Environment "
                "sending side channel data properly.".format(channel_id)
            )
        if channel_id in side_channels:
            incoming_message = IncomingMessage(message_data)
            side_channels[channel_id].on_message_received(incoming_message)
        else:
            logger.warning(
                "Unknown side channel data received. Channel type "
                ": {0}.".format(channel_id)
            )
@staticmethod
def _generate_side_channel_data(
    side_channels: Dict[uuid.UUID, SideChannel]
) -> bytearray:
    """
    Pack every queued message of the given side channels into one payload.
    Each message is written as the channel id (little-endian UUID bytes), a
    4 byte little-endian signed length, and the message itself. The queue of
    each channel is replaced by an empty list once it has been packed.
    :param side_channels: Mapping of channel id to SideChannel.
    :return: The concatenated payload.
    """
    payload = bytearray()
    for channel_id in side_channels:
        channel = side_channels[channel_id]
        header_id = channel_id.bytes_le
        for message in channel.message_queue:
            payload += header_id
            payload += struct.pack("<i", len(message))
            payload += message
        channel.message_queue = []
    return payload
@timed
def _generate_step_input(
self, vector_action: Dict[str, np.ndarray]

action = AgentActionProto(vector_actions=vector_action[b][i])
rl_in.agent_actions[b].value.extend([action])
rl_in.command = STEP
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels))
rl_in.side_channel = bytes(
self.side_channel_manager.generate_side_channel_messages()
)
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels))
rl_in.side_channel = bytes(
self.side_channel_manager.generate_side_channel_messages()
)
return self.wrap_unity_input(rl_in)
def send_academy_parameters(

12
ml-agents-envs/mlagents_envs/tests/test_envs.py


UnityEnvironment(" ")
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
def test_initialization(mock_communicator, mock_launcher):
mock_communicator.return_value = MockCommunicator(

(None, None, UnityEnvironment.DEFAULT_EDITOR_PORT),
],
)
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
def test_port_defaults(
mock_communicator, mock_launcher, base_port, file_name, expected

assert expected == env.port
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
def test_log_file_path_is_set(mock_communicator, mock_launcher):
mock_communicator.return_value = MockCommunicator()

assert args[log_file_index + 1] == "./some-log-folder-path/Player-0.log"
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
def test_reset(mock_communicator, mock_launcher):
mock_communicator.return_value = MockCommunicator(

assert (n_agents,) + shape == obs.shape
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
def test_step(mock_communicator, mock_launcher):
mock_communicator.return_value = MockCommunicator(

assert 2 in terminal_steps
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
def test_close(mock_communicator, mock_launcher):
comm = MockCommunicator(discrete_action=False, visual_inputs=0)

47
ml-agents-envs/mlagents_envs/tests/test_side_channel.py


import uuid
import pytest
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
from mlagents_envs.side_channel.side_channel_manager import SideChannelManager
from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel
from mlagents_envs.side_channel.raw_bytes_channel import RawBytesChannel
from mlagents_envs.side_channel.engine_configuration_channel import (

StatsSideChannel,
StatsAggregationMethod,
)
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.exception import (
UnitySideChannelException,
UnityCommunicationException,

receiver = IntChannel()
sender.send_int(5)
sender.send_int(6)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
assert receiver.list_int[0] == 5
assert receiver.list_int[1] == 6

sender.set_property("prop1", 1.0)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
val = receiver.get_property("prop1")
assert val == 1.0

data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
val = receiver.get_property("prop1")
assert val == 1.0

sender.send_raw_data("foo".encode("ascii"))
sender.send_raw_data("bar".encode("ascii"))
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
messages = receiver.get_and_clear_received_messages()
assert len(messages) == 2

config = EngineConfig.default_config()
sender.set_configuration(config)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
received_data = receiver.get_and_clear_received_messages()
assert len(received_data) == 5 # 5 different messages one for each setting

data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
message = IncomingMessage(receiver.get_and_clear_received_messages()[0])
message.read_int32()

with pytest.raises(UnityCommunicationException):
# try to send data to the EngineConfigurationChannel
sender.set_configuration_parameters(time_scale=sent_time_scale)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message(
{receiver.channel_id: sender}, data
)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([sender]).process_side_channel_message(data)
def test_environment_parameters():

sender.set_float_parameter("param-1", 0.1)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
message = IncomingMessage(receiver.get_and_clear_received_messages()[0])
key = message.read_string()

sender.set_float_parameter("param-2", 0.1)
sender.set_float_parameter("param-3", 0.1)
data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message({receiver.channel_id: receiver}, data)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([receiver]).process_side_channel_message(data)
assert len(receiver.get_and_clear_received_messages()) == 3

data = UnityEnvironment._generate_side_channel_data({sender.channel_id: sender})
UnityEnvironment._parse_side_channel_message(
{receiver.channel_id: sender}, data
)
data = SideChannelManager([sender]).generate_side_channel_messages()
SideChannelManager([sender]).process_side_channel_message(data)
def test_stats_channel():

8
ml-agents/mlagents/trainers/learn.py


from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.timers import (
hierarchical_timer,
get_timer_tree,

env_args: Optional[List[str]],
log_folder: str,
) -> Callable[[int, List[SideChannel]], BaseEnv]:
if env_path is not None:
launch_string = UnityEnvironment.validate_environment_path(env_path)
if launch_string is None:
raise UnityEnvironmentException(
f"Couldn't launch the {env_path} environment. Provided filename does not match any environments."
)
def create_unity_environment(
worker_id: int, side_channels: List[SideChannel]
) -> UnityEnvironment:

5
ml-agents/mlagents/trainers/tests/test_learn.py


def test_bad_env_path():
with pytest.raises(UnityEnvironmentException):
learn.create_environment_factory(
factory = learn.create_environment_factory(
seed=None,
seed=-1,
factory(worker_id=-1, side_channels=[])
@patch("builtins.open", new_callable=mock_open, read_data=MOCK_YAML)

101
ml-agents-envs/mlagents_envs/env_utils.py


import glob
import os
import subprocess
from sys import platform
from typing import Optional, List
from mlagents_envs.logging_util import get_logger
from mlagents_envs.exception import UnityEnvironmentException
def validate_environment_path(env_path: str) -> Optional[str]:
    """
    Resolve an environment path to the executable file for the current platform.
    A trailing executable extension (.app, .exe, .x86_64 or .x86) is removed
    from env_path, then platform-specific locations are searched both relative
    to the current working directory and relative to the path as given.
    :param env_path: The path to the executable, with or without its extension.
    :return: The first matching executable path, or None when nothing matches
    or the platform is not linux, darwin or win32.
    """
    env_path = env_path.strip()
    # Only strip a trailing extension: removing the text anywhere in the path
    # would corrupt directory names such as "builds.x86/env".
    for extension in (".app", ".exe", ".x86_64", ".x86"):
        if env_path.endswith(extension):
            env_path = env_path[: -len(extension)]
            break
    true_filename = os.path.basename(os.path.normpath(env_path))
    get_logger(__name__).debug("The true file name is {}".format(true_filename))

    # Bail out early when nothing on disk resembles the requested path.
    if not (glob.glob(env_path) or glob.glob(env_path + ".*")):
        return None

    cwd = os.getcwd()
    # Glob patterns are listed in priority order; the first one that matches wins.
    if platform == "linux" or platform == "linux2":
        patterns = [
            os.path.join(cwd, env_path) + ".x86_64",
            os.path.join(cwd, env_path) + ".x86",
            env_path + ".x86_64",
            env_path + ".x86",
        ]
    elif platform == "darwin":
        patterns = [
            os.path.join(cwd, env_path + ".app", "Contents", "MacOS", true_filename),
            os.path.join(env_path + ".app", "Contents", "MacOS", true_filename),
            os.path.join(cwd, env_path + ".app", "Contents", "MacOS", "*"),
            os.path.join(env_path + ".app", "Contents", "MacOS", "*"),
        ]
    elif platform == "win32":
        patterns = [os.path.join(cwd, env_path + ".exe"), env_path + ".exe"]
    else:
        patterns = []

    for pattern in patterns:
        candidates = glob.glob(pattern)
        if candidates:
            return candidates[0]
    return None
def launch_executable(file_name: str, args: List[str]) -> subprocess.Popen:
    """
    Start a Unity executable as a child process and hand back its handle.
    :param file_name: the name of the executable
    :param args: List of string that will be passed as command line arguments
    when launching the executable.
    :return: The subprocess.Popen handle of the launched executable.
    :raises UnityEnvironmentException: When no executable matches file_name or
    launching it fails with a PermissionError.
    """
    launch_string = validate_environment_path(file_name)
    if launch_string is None:
        raise UnityEnvironmentException(
            f"Couldn't launch the {file_name} environment. Provided filename does not match any environments."
        )

    get_logger(__name__).debug("This is the launch string {}".format(launch_string))
    # Launch Unity environment
    command = [launch_string] + args
    try:
        # start_new_session=True means that signals to the parent python process
        # (e.g. SIGINT from keyboard interrupt) will not be sent to the new process on POSIX platforms.
        # This is generally good since we want the environment to have a chance to shutdown,
        # but may be undesirable in some cases; if so, we'll add a command-line toggle.
        # Note that on Windows, the CTRL_C signal will still be sent.
        process = subprocess.Popen(command, start_new_session=True)
    except PermissionError as perm:
        # This is likely due to missing read or execute permissions on file.
        raise UnityEnvironmentException(
            f"Error when trying to launch environment - make sure "
            f"permissions are set correctly. For example "
            f'"chmod -R 755 {launch_string}"'
        ) from perm
    return process

81
ml-agents-envs/mlagents_envs/side_channel/side_channel_manager.py


import uuid
import struct
from typing import Dict, Optional, List
from mlagents_envs.side_channel import SideChannel, IncomingMessage
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.logging_util import get_logger
class SideChannelManager:
    """
    Packs and unpacks the side channel payload for a set of registered
    side channels, keyed by their channel id.
    """

    def __init__(self, side_channels: Optional[List[SideChannel]] = None):
        """
        :param side_channels: The side channels to register. None (the default)
        registers no channel.
        :raises UnityEnvironmentException: If two channels share a channel id.
        """
        self._side_channels_dict = self._get_side_channels_dict(side_channels)

    def process_side_channel_message(self, data: bytes) -> None:
        """
        Separates the data received from Unity into individual messages for each
        registered side channel and calls on_message_received on them.
        Every message is laid out as a 16 byte little-endian channel UUID, a
        4 byte little-endian signed length, and then the message data.
        :param data: The packed message sent by Unity
        :raises UnityEnvironmentException: If a header cannot be decoded or a
        message body is shorter than its declared length.
        """
        offset = 0
        while offset < len(data):
            try:
                channel_id = uuid.UUID(bytes_le=bytes(data[offset : offset + 16]))
                offset += 16
                (message_len,) = struct.unpack_from("<i", data, offset)
                offset = offset + 4
                message_data = data[offset : offset + message_len]
                offset = offset + message_len
            except (struct.error, ValueError, IndexError) as err:
                # Keep the low-level decoding error attached as the cause.
                raise UnityEnvironmentException(
                    "There was a problem reading a message in a SideChannel. "
                    "Please make sure the version of MLAgents in Unity is "
                    "compatible with the Python version."
                ) from err
            # Slicing never raises, so a truncated body shows up as a length mismatch.
            if len(message_data) != message_len:
                raise UnityEnvironmentException(
                    "The message received by the side channel {0} was "
                    "unexpectedly short. Make sure your Unity Environment "
                    "sending side channel data properly.".format(channel_id)
                )
            if channel_id in self._side_channels_dict:
                incoming_message = IncomingMessage(message_data)
                self._side_channels_dict[channel_id].on_message_received(
                    incoming_message
                )
            else:
                get_logger(__name__).warning(
                    f"Unknown side channel data received. Channel type: {channel_id}."
                )

    def generate_side_channel_messages(self) -> bytearray:
        """
        Gathers the messages that the registered side channels will send to Unity
        and combines them into a single message ready to be sent.
        The message queue of every channel is replaced by an empty list.
        :return: The packed messages; empty when nothing is queued.
        """
        result = bytearray()
        for channel_id, channel in self._side_channels_dict.items():
            for message in channel.message_queue:
                result += channel_id.bytes_le
                result += struct.pack("<i", len(message))
                result += message
            channel.message_queue = []
        return result

    @staticmethod
    def _get_side_channels_dict(
        side_channels: Optional[List[SideChannel]]
    ) -> Dict[uuid.UUID, SideChannel]:
        """
        Converts a list of side channels into a dictionary of channel_id to SideChannel
        :param side_channels: The list of side channels, or None for no channel.
        :return: The mapping of channel id to side channel.
        :raises UnityEnvironmentException: If two channels share a channel id.
        """
        side_channels_dict: Dict[uuid.UUID, SideChannel] = {}
        if side_channels is not None:
            for _sc in side_channels:
                if _sc.channel_id in side_channels_dict:
                    raise UnityEnvironmentException(
                        f"There cannot be two side channels with "
                        f"the same channel id {_sc.channel_id}."
                    )
                side_channels_dict[_sc.channel_id] = _sc
        return side_channels_dict
正在加载...
取消
保存