您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
131 行
5.0 KiB
131 行
5.0 KiB
import logging
|
|
import grpc
|
|
from typing import Optional
|
|
|
|
import socket
|
|
from multiprocessing import Pipe
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from .communicator import Communicator
|
|
from mlagents.envs.communicator_objects.unity_to_external_pb2_grpc import (
|
|
UnityToExternalProtoServicer,
|
|
add_UnityToExternalProtoServicer_to_server,
|
|
)
|
|
from mlagents.envs.communicator_objects.unity_message_pb2 import UnityMessageProto
|
|
from mlagents.envs.communicator_objects.unity_input_pb2 import UnityInputProto
|
|
from mlagents.envs.communicator_objects.unity_output_pb2 import UnityOutputProto
|
|
from .exception import UnityTimeOutException, UnityWorkerInUseException
|
|
|
|
# Module-level logger under the "mlagents.envs" namespace (not referenced in the
# code visible here; kept for use elsewhere in the module).
logger = logging.getLogger("mlagents.envs")
|
|
|
|
|
|
class UnityToExternalServicerImplementation(UnityToExternalProtoServicer):
    """
    gRPC servicer that relays every request received from Unity through a
    multiprocessing pipe.

    The gRPC handler side talks on ``child_conn``; the Python-side
    communicator reads requests from and writes replies to ``parent_conn``.
    """

    def __init__(self):
        # multiprocessing.Pipe() is duplex by default, so both ends can send and receive.
        self.parent_conn, self.child_conn = Pipe()

    def _relay(self, request):
        """Push ``request`` into the pipe and block until a reply is sent back."""
        conn = self.child_conn
        conn.send(request)
        return conn.recv()

    def Initialize(self, request, context):
        """Handle Unity's Initialize RPC by relaying the request through the pipe."""
        return self._relay(request)

    def Exchange(self, request, context):
        """Handle Unity's Exchange RPC by relaying the request through the pipe."""
        return self._relay(request)
|
|
|
|
|
|
class RpcCommunicator(Communicator):
    """
    Python side of the gRPC communication: Python hosts the server and the
    Unity environment connects to it as a client. Messages are passed between
    the gRPC handler threads and this object through a multiprocessing pipe.
    """

    def __init__(self, worker_id=0, base_port=5005, timeout_wait=30):
        """
        Python side of the grpc communication. Python is the server and Unity the client

        :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
        :int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
        :int timeout_wait: Seconds to wait for a message from Unity before raising UnityTimeOutException.
        """
        super().__init__(worker_id, base_port)
        self.port = base_port + worker_id
        self.worker_id = worker_id
        self.timeout_wait = timeout_wait
        self.server = None
        self.unity_to_external = None
        self.is_open = False
        self.create_server()

    def create_server(self):
        """
        Creates the GRPC server and starts it listening on ``self.port``.

        :raises UnityWorkerInUseException: if the port is already in use, or the
            server could not be created, bound or started. The underlying error
            is chained as ``__cause__``.
        """
        self.check_port(self.port)

        try:
            # Establish communication grpc
            self.server = grpc.server(ThreadPoolExecutor(max_workers=10))
            self.unity_to_external = UnityToExternalServicerImplementation()
            add_UnityToExternalProtoServicer_to_server(
                self.unity_to_external, self.server
            )
            # Using unspecified address, which means that grpc is communicating on all IPs
            # This is so that the docker container can connect.
            bound_port = self.server.add_insecure_port("[::]:" + str(self.port))
            # Older grpcio releases signal a failed bind by returning 0 instead of
            # raising; without this check the server would "start" but never listen.
            if bound_port == 0:
                raise RuntimeError(
                    "Failed to bind gRPC server to port {}".format(self.port)
                )
            self.server.start()
            self.is_open = True
        except Exception as e:
            raise UnityWorkerInUseException(self.worker_id) from e

    def check_port(self, port):
        """
        Attempts to bind to the requested communicator port, checking if it is already in use.

        :param port: TCP port to probe on ``localhost``.
        :raises UnityWorkerInUseException: if the port cannot be bound.
        """
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            s.bind(("localhost", port))
        except socket.error as e:
            raise UnityWorkerInUseException(self.worker_id) from e
        finally:
            # The probe socket is only used for the bind test; always release it.
            s.close()

    def poll_for_timeout(self):
        """
        Polls the GRPC parent connection for data, to be used before calling recv. This prevents
        us from hanging indefinitely in the case where the environment process has died or was not
        launched.

        :raises UnityTimeOutException: if nothing arrives within ``self.timeout_wait`` seconds.
        """
        if not self.unity_to_external.parent_conn.poll(self.timeout_wait):
            raise UnityTimeOutException(
                "The Unity environment took too long to respond. Make sure that :\n"
                "\t The environment does not need user interaction to launch\n"
                "\t The Agents are linked to the appropriate Brains\n"
                "\t The environment and the Python interface have compatible versions."
            )

    def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
        """
        Perform the initialization handshake with Unity.

        Waits for Unity's first message, replies to it with ``inputs`` (status
        200), then consumes one more message from the pipe before returning.

        :param inputs: Input proto to send back to Unity.
        :returns: The ``unity_output`` carried by Unity's first message.
        :raises UnityTimeOutException: if Unity's first message does not arrive in time.
        """
        self.poll_for_timeout()
        aca_param = self.unity_to_external.parent_conn.recv().unity_output
        message = UnityMessageProto()
        message.header.status = 200
        message.unity_input.CopyFrom(inputs)
        self.unity_to_external.parent_conn.send(message)
        # NOTE(review): this recv() is not preceded by poll_for_timeout(), so it can
        # block indefinitely if Unity never sends its next message — confirm intended.
        self.unity_to_external.parent_conn.recv()
        return aca_param

    def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]:
        """
        Send ``inputs`` to Unity (status 200) and wait for its reply.

        :param inputs: Input proto to send to Unity.
        :returns: The reply's ``unity_output``, or None if the reply's header
            status is not 200.
        :raises UnityTimeOutException: if Unity does not reply in time.
        """
        message = UnityMessageProto()
        message.header.status = 200
        message.unity_input.CopyFrom(inputs)
        self.unity_to_external.parent_conn.send(message)
        self.poll_for_timeout()
        output = self.unity_to_external.parent_conn.recv()
        if output.header.status != 200:
            return None
        return output.unity_output

    def close(self):
        """
        Sends a shutdown signal to the unity environment, and closes the grpc connection.

        Does nothing if the communicator is not open. The pipe and server are
        torn down even if sending the shutdown message fails.
        """
        if not self.is_open:
            return
        try:
            message_input = UnityMessageProto()
            # Status 400 is the shutdown signal sent to Unity.
            message_input.header.status = 400
            self.unity_to_external.parent_conn.send(message_input)
        finally:
            # Always release the pipe and stop the server so a failed send does
            # not leave the gRPC server running with is_open still True.
            self.unity_to_external.parent_conn.close()
            self.server.stop(False)
            self.is_open = False