Unity 机器学习代理工具包（ML-Agents）是一个开源项目，它使游戏和模拟能够作为训练智能代理的环境。
您最多选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。
 
 
 
 
 

81 行
3.5 KiB

import uuid
import struct
from typing import Dict, Optional, List
from mlagents_envs.side_channel import SideChannel, IncomingMessage
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.logging_util import get_logger
class SideChannelManager:
    """
    Routes side channel traffic between Python and Unity.

    Data received from Unity is split into per-channel messages and dispatched
    to the registered SideChannel whose channel_id matches. Messages queued on
    the registered channels are packed into a single payload for Unity.

    Each message on the wire is laid out back to back as:
        16 bytes  channel id (uuid.UUID in bytes_le order)
         4 bytes  payload length (little-endian signed 32-bit int)
         N bytes  payload
    """

    def __init__(self, side_channels: Optional[List[SideChannel]] = None) -> None:
        """
        :param side_channels: The side channels to register, or None to
            register no channels. Channel ids must be unique.
        :raises UnityEnvironmentException: If two side channels share the same
            channel id.
        """
        self._side_channels_dict = self._get_side_channels_dict(side_channels)

    def process_side_channel_message(self, data: bytes) -> None:
        """
        Separates the data received from Unity into individual messages for each
        registered side channel and calls on_message_received on them.

        Messages addressed to a channel id that is not registered are skipped
        and a warning is logged.

        :param data: The packed message sent by Unity.
        :raises UnityEnvironmentException: If a message header cannot be parsed,
            or a payload does not match its declared length.
        """
        offset = 0
        while offset < len(data):
            try:
                channel_id = uuid.UUID(bytes_le=bytes(data[offset : offset + 16]))
                offset += 16
                (message_len,) = struct.unpack_from("<i", data, offset)
                offset += 4
            except (struct.error, ValueError, IndexError) as err:
                # A short UUID raises ValueError; a short length field raises
                # struct.error. Both mean the stream does not match this format.
                raise UnityEnvironmentException(
                    "There was a problem reading a message in a SideChannel. "
                    "Please make sure the version of MLAgents in Unity is "
                    "compatible with the Python version."
                ) from err

            message_data = data[offset : offset + message_len]
            offset += message_len
            # Slicing never raises, so a truncated payload shows up as a length
            # mismatch here. A negative declared length can never equal len(),
            # so it is rejected here as well.
            if len(message_data) != message_len:
                raise UnityEnvironmentException(
                    "The message received by the side channel {} was "
                    "unexpectedly short. Make sure your Unity Environment "
                    "sending side channel data properly.".format(channel_id)
                )

            channel = self._side_channels_dict.get(channel_id)
            if channel is None:
                get_logger(__name__).warning(
                    f"Unknown side channel data received. Channel type: {channel_id}."
                )
                continue
            channel.on_message_received(IncomingMessage(message_data))

    def generate_side_channel_messages(self) -> bytearray:
        """
        Gathers the messages that the registered side channels will send to Unity
        and combines them into a single message ready to be sent.

        Each channel's message_queue is replaced with an empty list once its
        messages have been packed.

        :return: The packed messages, in the same layout that
            process_side_channel_message parses.
        """
        result = bytearray()
        for channel_id, channel in self._side_channels_dict.items():
            for message in channel.message_queue:
                result += channel_id.bytes_le
                result += struct.pack("<i", len(message))
                result += message
            channel.message_queue = []
        return result

    @staticmethod
    def _get_side_channels_dict(
        side_channels: Optional[List[SideChannel]],
    ) -> Dict[uuid.UUID, SideChannel]:
        """
        Converts a list of side channels into a dictionary of channel_id to SideChannel.

        :param side_channels: The list of side channels, or None.
        :return: A mapping from each channel's channel_id to the channel; empty
            when side_channels is None.
        :raises UnityEnvironmentException: If two side channels share the same
            channel id.
        """
        side_channels_dict: Dict[uuid.UUID, SideChannel] = {}
        if side_channels is not None:
            for _sc in side_channels:
                if _sc.channel_id in side_channels_dict:
                    raise UnityEnvironmentException(
                        f"There cannot be two side channels with "
                        f"the same channel id {_sc.channel_id}."
                    )
                side_channels_dict[_sc.channel_id] = _sc
        return side_channels_dict