import logging
import socket
import struct
from typing import Optional

from .communicator import Communicator
from mlagents.envs.communicator_objects.unity_message_pb2 import UnityMessage
from mlagents.envs.communicator_objects.unity_output_pb2 import UnityOutput
from mlagents.envs.communicator_objects.unity_input_pb2 import UnityInput
from .exception import UnityTimeOutException


logger = logging.getLogger("mlagents.envs")


class SocketCommunicator(Communicator):
    def __init__(self, worker_id=0, base_port=5005):
        """
        Python side of the socket communication with a Unity environment.

        :param int worker_id: Number added to base_port to obtain the port this
            communicator listens on. Defaults to 0. Used for asynchronous agent
            scenarios, where each environment needs its own port.
        :param int base_port: Baseline port number used to connect to the Unity
            environment; worker_id is added to it. Defaults to 5005.
        """
        self.port = base_port + worker_id
        self._buffer_size = 12000
        self.worker_id = worker_id
        self._socket = None
        self._conn = None

    def initialize(self, inputs: UnityInput) -> UnityOutput:
        try:
            # Establish communication socket
            self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self._socket.bind(("localhost", self.port))
        except Exception:
            raise UnityTimeOutException(
                "Couldn't start socket communication because worker number {} is still in use. "
                "You may need to manually close a previously opened environment "
                "or use a different worker number.".format(self.worker_id)
            )
        try:
            # Wait up to 30 seconds for the Unity environment to connect.
            self._socket.settimeout(30)
            self._socket.listen(1)
            self._conn, _ = self._socket.accept()
            self._conn.settimeout(30)
        except Exception:
            raise UnityTimeOutException(
                "The Unity environment took too long to respond. Make sure that:\n"
                "\t The environment does not need user interaction to launch\n"
                "\t The Academy's Broadcast Hub is configured correctly\n"
                "\t The Agents are linked to the appropriate Brains\n"
                "\t The environment and the Python interface have compatible versions."
            )
        message = UnityMessage()
        message.header.status = 200
        message.unity_input.CopyFrom(inputs)
        self._communicator_send(message.SerializeToString())
        initialization_output = UnityMessage()
        initialization_output.ParseFromString(self._communicator_receive())
        return initialization_output.unity_output

    def _communicator_receive(self):
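        # Each message is a 4-byte unsigned length prefix (struct format "I")
        # followed by the serialized UnityMessage. recv() may return fewer bytes
        # than requested, so keep reading until the whole frame has arrived, and
        # stop if the peer closes the connection instead of looping forever.
        try:
            s = b""
            while len(s) < 4:
                chunk = self._conn.recv(self._buffer_size)
                if not chunk:
                    raise UnityTimeOutException(
                        "The environment closed the connection."
                    )
                s += chunk
            message_length = struct.unpack("I", s[:4])[0]
            s = s[4:]
            while len(s) < message_length:
                chunk = self._conn.recv(self._buffer_size)
                if not chunk:
                    raise UnityTimeOutException(
                        "The environment closed the connection."
                    )
                s += chunk
        except socket.timeout:
            raise UnityTimeOutException("The environment took too long to respond.")
        return s

    def _communicator_send(self, message):
        # sendall() keeps writing until every byte is sent; send() may stop short.
        self._conn.sendall(struct.pack("I", len(message)) + message)

    def exchange(self, inputs: UnityInput) -> Optional[UnityOutput]: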
        message = UnityMessage()
        message.header.status = 200
        message.unity_input.CopyFrom(inputs)
        self._communicator_send(message.SerializeToString())
        outputs = UnityMessage()
        outputs.ParseFromString(self._communicator_receive())
        if outputs.header.status != 200:
            return None
        return outputs.unity_output

    def close(self):
        """
        Sends a shutdown signal to the Unity environment and closes the socket connection.
        """
        if self._socket is not None and self._conn is not None:
            # A message with header status 400 is the shutdown signal.
            message_input = UnityMessage()
            message_input.header.status = 400
            self._communicator_send(message_input.SerializeToString())
        # Close the accepted connection first, then the listening socket.
        if self._conn is not None:
            self._conn.close()
            self._conn = None
        if self._socket is not None:
            self._socket.close()
            self._socket = None
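

# ---------------------------------------------------------------------------
# Illustrative sketch: hypothetical helpers that are NOT part of the original
# module and are not called by SocketCommunicator. They restate, as pure
# functions, the framing assumed by _communicator_send and
# _communicator_receive: a 4-byte unsigned length prefix packed with struct
# format "I" (native byte order), followed by the serialized UnityMessage.
# Keeping the framing in pure functions lets it be checked without a running
# Unity environment. The names frame_message and split_frame are assumptions.
# ---------------------------------------------------------------------------
import struct  # repeated so this sketch stands on its own
from typing import Tuple

_LENGTH_PREFIX = struct.Struct("I")


def frame_message(payload: bytes) -> bytes:
    """Prefix payload with its length, mirroring _communicator_send."""
    return _LENGTH_PREFIX.pack(len(payload)) + payload


def split_frame(data: bytes) -> Tuple[bytes, bytes]:
    """Split one complete frame off the front of data.

    Returns (payload, remaining_bytes). Raises ValueError when data does not
    yet contain a whole frame, i.e. when a receiver would have to keep reading.
    """
    header_size = _LENGTH_PREFIX.size
    if len(data) < header_size:
        raise ValueError("incomplete length prefix")
    (length,) = _LENGTH_PREFIX.unpack_from(data)
    end = header_size + length
    if len(data) < end:
        raise ValueError("incomplete payload")
    return data[header_size:end], data[end:]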