浏览代码
Some improvements to the UnityEnvironment class (#3939)
Some improvements to the UnityEnvironment class (#3939)
* Fix typo * Made a side channel utils to reduce the complexity of UnityEnvironment * Added a get_side_channel_dict utils method * Better executable launcher (unarguably) * Fixing the broken test * Addressing comments * [skip ci] Update ml-agents-envs/mlagents_envs/side_channel/side_channel_manager.py Co-authored-by: Jonathan Harper <jharper+moar@unity3d.com> * No catch all Co-authored-by: Jonathan Harper <jharper+moar@unity3d.com>/docs-update
GitHub
5 年前
当前提交
812983c0
共有 7 个文件被更改,包括 234 次插入 和 194 次删除
-
174ml-agents-envs/mlagents_envs/environment.py
-
12ml-agents-envs/mlagents_envs/tests/test_envs.py
-
47ml-agents-envs/mlagents_envs/tests/test_side_channel.py
-
8ml-agents/mlagents/trainers/learn.py
-
5ml-agents/mlagents/trainers/tests/test_learn.py
-
101ml-agents-envs/mlagents_envs/env_utils.py
-
81ml-agents-envs/mlagents_envs/side_channel/side_channel_manager.py
|
|||
import glob |
|||
import os |
|||
import subprocess |
|||
from sys import platform |
|||
from typing import Optional, List |
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents_envs.exception import UnityEnvironmentException |
|||
|
|||
|
|||
def validate_environment_path(env_path: str) -> Optional[str]: |
|||
""" |
|||
Strip out executable extensions of the env_path |
|||
:param env_path: The path to the executable |
|||
""" |
|||
env_path = ( |
|||
env_path.strip() |
|||
.replace(".app", "") |
|||
.replace(".exe", "") |
|||
.replace(".x86_64", "") |
|||
.replace(".x86", "") |
|||
) |
|||
true_filename = os.path.basename(os.path.normpath(env_path)) |
|||
get_logger(__name__).debug("The true file name is {}".format(true_filename)) |
|||
|
|||
if not (glob.glob(env_path) or glob.glob(env_path + ".*")): |
|||
return None |
|||
|
|||
cwd = os.getcwd() |
|||
launch_string = None |
|||
true_filename = os.path.basename(os.path.normpath(env_path)) |
|||
if platform == "linux" or platform == "linux2": |
|||
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86_64") |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86") |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob(env_path + ".x86_64") |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob(env_path + ".x86") |
|||
if len(candidates) > 0: |
|||
launch_string = candidates[0] |
|||
|
|||
elif platform == "darwin": |
|||
candidates = glob.glob( |
|||
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", true_filename) |
|||
) |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob( |
|||
os.path.join(env_path + ".app", "Contents", "MacOS", true_filename) |
|||
) |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob( |
|||
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", "*") |
|||
) |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob( |
|||
os.path.join(env_path + ".app", "Contents", "MacOS", "*") |
|||
) |
|||
if len(candidates) > 0: |
|||
launch_string = candidates[0] |
|||
elif platform == "win32": |
|||
candidates = glob.glob(os.path.join(cwd, env_path + ".exe")) |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob(env_path + ".exe") |
|||
if len(candidates) > 0: |
|||
launch_string = candidates[0] |
|||
return launch_string |
|||
|
|||
|
|||
def launch_executable(file_name: str, args: List[str]) -> subprocess.Popen: |
|||
""" |
|||
Launches a Unity executable and returns the process handle for it. |
|||
:param file_name: the name of the executable |
|||
:param args: List of string that will be passed as command line arguments |
|||
when launching the executable. |
|||
""" |
|||
launch_string = validate_environment_path(file_name) |
|||
if launch_string is None: |
|||
raise UnityEnvironmentException( |
|||
f"Couldn't launch the {file_name} environment. Provided filename does not match any environments." |
|||
) |
|||
else: |
|||
get_logger(__name__).debug("This is the launch string {}".format(launch_string)) |
|||
# Launch Unity environment |
|||
subprocess_args = [launch_string] + args |
|||
try: |
|||
return subprocess.Popen( |
|||
subprocess_args, |
|||
# start_new_session=True means that signals to the parent python process |
|||
# (e.g. SIGINT from keyboard interrupt) will not be sent to the new process on POSIX platforms. |
|||
# This is generally good since we want the environment to have a chance to shutdown, |
|||
# but may be undesirable in come cases; if so, we'll add a command-line toggle. |
|||
# Note that on Windows, the CTRL_C signal will still be sent. |
|||
start_new_session=True, |
|||
) |
|||
except PermissionError as perm: |
|||
# This is likely due to missing read or execute permissions on file. |
|||
raise UnityEnvironmentException( |
|||
f"Error when trying to launch environment - make sure " |
|||
f"permissions are set correctly. For example " |
|||
f'"chmod -R 755 {launch_string}"' |
|||
) from perm |
|
|||
import uuid |
|||
import struct |
|||
from typing import Dict, Optional, List |
|||
from mlagents_envs.side_channel import SideChannel, IncomingMessage |
|||
from mlagents_envs.exception import UnityEnvironmentException |
|||
from mlagents_envs.logging_util import get_logger |
|||
|
|||
|
|||
class SideChannelManager: |
|||
def __init__(self, side_channels=Optional[List[SideChannel]]): |
|||
self._side_channels_dict = self._get_side_channels_dict(side_channels) |
|||
|
|||
def process_side_channel_message(self, data: bytes) -> None: |
|||
""" |
|||
Separates the data received from Python into individual messages for each |
|||
registered side channel and calls on_message_received on them. |
|||
:param data: The packed message sent by Unity |
|||
""" |
|||
offset = 0 |
|||
while offset < len(data): |
|||
try: |
|||
channel_id = uuid.UUID(bytes_le=bytes(data[offset : offset + 16])) |
|||
offset += 16 |
|||
message_len, = struct.unpack_from("<i", data, offset) |
|||
offset = offset + 4 |
|||
message_data = data[offset : offset + message_len] |
|||
offset = offset + message_len |
|||
except (struct.error, ValueError, IndexError): |
|||
raise UnityEnvironmentException( |
|||
"There was a problem reading a message in a SideChannel. " |
|||
"Please make sure the version of MLAgents in Unity is " |
|||
"compatible with the Python version." |
|||
) |
|||
if len(message_data) != message_len: |
|||
raise UnityEnvironmentException( |
|||
"The message received by the side channel {0} was " |
|||
"unexpectedly short. Make sure your Unity Environment " |
|||
"sending side channel data properly.".format(channel_id) |
|||
) |
|||
if channel_id in self._side_channels_dict: |
|||
incoming_message = IncomingMessage(message_data) |
|||
self._side_channels_dict[channel_id].on_message_received( |
|||
incoming_message |
|||
) |
|||
else: |
|||
get_logger(__name__).warning( |
|||
f"Unknown side channel data received. Channel type: {channel_id}." |
|||
) |
|||
|
|||
def generate_side_channel_messages(self) -> bytearray: |
|||
""" |
|||
Gathers the messages that the registered side channels will send to Unity |
|||
and combines them into a single message ready to be sent. |
|||
""" |
|||
result = bytearray() |
|||
for channel_id, channel in self._side_channels_dict.items(): |
|||
for message in channel.message_queue: |
|||
result += channel_id.bytes_le |
|||
result += struct.pack("<i", len(message)) |
|||
result += message |
|||
channel.message_queue = [] |
|||
return result |
|||
|
|||
@staticmethod |
|||
def _get_side_channels_dict( |
|||
side_channels: Optional[List[SideChannel]] |
|||
) -> Dict[uuid.UUID, SideChannel]: |
|||
""" |
|||
Converts a list of side channels into a dictionary of channel_id to SideChannel |
|||
:param side_channels: The list of side channels. |
|||
""" |
|||
side_channels_dict: Dict[uuid.UUID, SideChannel] = {} |
|||
if side_channels is not None: |
|||
for _sc in side_channels: |
|||
if _sc.channel_id in side_channels_dict: |
|||
raise UnityEnvironmentException( |
|||
f"There cannot be two side channels with " |
|||
f"the same channel id {_sc.channel_id}." |
|||
) |
|||
side_channels_dict[_sc.channel_id] = _sc |
|||
return side_channels_dict |
撰写
预览
正在加载...
取消
保存
Reference in new issue