Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择25个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符。
 
 
 
 
 

130 行
4.9 KiB

import logging
import grpc
from typing import Optional
import socket
from multiprocessing import Pipe
from concurrent.futures import ThreadPoolExecutor
from .communicator import Communicator
from mlagents.envs.communicator_objects.unity_to_external_pb2_grpc import (
UnityToExternalProtoServicer,
add_UnityToExternalProtoServicer_to_server,
)
from mlagents.envs.communicator_objects.unity_message_pb2 import UnityMessageProto
from mlagents.envs.communicator_objects.unity_input_pb2 import UnityInputProto
from mlagents.envs.communicator_objects.unity_output_pb2 import UnityOutputProto
from .exception import UnityTimeOutException, UnityWorkerInUseException
# Module-level logger obtained under the "mlagents.envs" logger name.
logger = logging.getLogger("mlagents.envs")
class UnityToExternalServicerImplementation(UnityToExternalProtoServicer):
    """
    gRPC servicer that relays every incoming request through a pipe.

    Each RPC handler sends the request on ``child_conn`` and then blocks on
    ``child_conn`` until an object is received, which is returned as the RPC reply.
    The opposite end of the pipe is exposed as ``parent_conn``.
    """

    def __init__(self):
        pipe_ends = Pipe()
        self.parent_conn = pipe_ends[0]
        self.child_conn = pipe_ends[1]

    def _relay(self, request):
        """
        Sends ``request`` on ``child_conn`` and returns the next object received on it.
        """
        conn = self.child_conn
        conn.send(request)
        reply = conn.recv()
        return reply

    def Initialize(self, request, context):
        """
        Handles the ``Initialize`` RPC by relaying ``request`` through the pipe.
        ``context`` is not used.
        """
        return self._relay(request)

    def Exchange(self, request, context):
        """
        Handles the ``Exchange`` RPC by relaying ``request`` through the pipe.
        ``context`` is not used.
        """
        return self._relay(request)
class RpcCommunicator(Communicator):
    """
    Python-side gRPC server used to exchange messages with a Unity environment.
    """

    def __init__(self, worker_id=0, base_port=5005, timeout_wait=30):
        """
        Python side of the grpc communication. Python is the server and Unity the client

        :int worker_id: Number to add to base_port to obtain the communication port [0].
            Used for asynchronous agent scenarios.
        :int base_port: Baseline port number to connect to Unity environment over.
            worker_id increments over this.
        :int timeout_wait: Timeout, in seconds, passed to the pipe poll performed while
            waiting for a message from Unity (see poll_for_timeout).
        """
        self.port = base_port + worker_id
        self.worker_id = worker_id
        self.timeout_wait = timeout_wait
        self.server = None
        self.unity_to_external = None
        self.is_open = False
        self.create_server()

    def create_server(self):
        """
        Creates the GRPC server.

        :raises UnityWorkerInUseException: if the port is already bound, or if creating or
            starting the server fails for any reason.
        """
        self.check_port(self.port)

        try:
            # Establish communication grpc
            self.server = grpc.server(ThreadPoolExecutor(max_workers=10))
            self.unity_to_external = UnityToExternalServicerImplementation()
            add_UnityToExternalProtoServicer_to_server(
                self.unity_to_external, self.server
            )
            # Using unspecified address, which means that grpc is communicating on all IPs
            # This is so that the docker container can connect.
            self.server.add_insecure_port("[::]:" + str(self.port))
            self.server.start()
            self.is_open = True
        except Exception as err:
            # Chain the original error explicitly so the root cause stays visible.
            raise UnityWorkerInUseException(self.worker_id) from err

    def check_port(self, port):
        """
        Attempts to bind to the requested communicator port, checking if it is already in use.

        :int port: Port number to probe on localhost.
        :raises UnityWorkerInUseException: if binding to the port fails.
        """
        # The context manager closes the probe socket on both success and failure.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            try:
                probe.bind(("localhost", port))
            except socket.error as err:
                raise UnityWorkerInUseException(self.worker_id) from err

    def poll_for_timeout(self):
        """
        Polls the GRPC parent connection for data, to be used before calling recv. This prevents
        us from hanging indefinitely in the case where the environment process has died or was not
        launched.

        :raises UnityTimeOutException: if no data is available within timeout_wait.
        """
        if not self.unity_to_external.parent_conn.poll(self.timeout_wait):
            raise UnityTimeOutException(
                "The Unity environment took too long to respond. Make sure that :\n"
                "\t The environment does not need user interaction to launch\n"
                "\t The Agents are linked to the appropriate Brains\n"
                "\t The environment and the Python interface have compatible versions."
            )

    def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
        """
        Performs the initial handshake with Unity.

        Waits for the first message from Unity, replies with ``inputs`` wrapped in a
        status-200 message, then waits for (and discards) the next message from Unity.

        :param inputs: The UnityInputProto to send back to Unity.
        :return: The ``unity_output`` field of the first message received from Unity.
        :raises UnityTimeOutException: if either message does not arrive within timeout_wait.
        """
        conn = self.unity_to_external.parent_conn

        self.poll_for_timeout()
        aca_param = conn.recv().unity_output

        message = UnityMessageProto()
        message.header.status = 200
        message.unity_input.CopyFrom(inputs)
        conn.send(message)

        # Guard the second recv as well; without the poll this call would block
        # forever if Unity went away after the first message.
        self.poll_for_timeout()
        conn.recv()
        return aca_param

    def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]:
        """
        Sends ``inputs`` to Unity in a status-200 message and waits for the reply.

        :param inputs: The UnityInputProto to send to Unity.
        :return: The ``unity_output`` field of the reply, or None if the reply header
            status is not 200.
        :raises UnityTimeOutException: if no reply arrives within timeout_wait.
        """
        message = UnityMessageProto()
        message.header.status = 200
        message.unity_input.CopyFrom(inputs)
        self.unity_to_external.parent_conn.send(message)
        self.poll_for_timeout()
        output = self.unity_to_external.parent_conn.recv()
        if output.header.status != 200:
            return None
        return output.unity_output

    def close(self):
        """
        Sends a shutdown signal to the unity environment, and closes the grpc connection.

        Does nothing if the communicator is not open. The pipe end and the server are
        released even if sending the shutdown message raises; the error still propagates.
        """
        if not self.is_open:
            return

        message_input = UnityMessageProto()
        message_input.header.status = 400
        conn = self.unity_to_external.parent_conn
        try:
            try:
                conn.send(message_input)
            finally:
                conn.close()
        finally:
            # Always stop the server and mark the communicator closed, so a failed
            # send cannot leave the server running behind a stale is_open flag.
            self.server.stop(False)
            self.is_open = False