您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
82 行
2.6 KiB
82 行
2.6 KiB
import pytest
|
|
from unittest import mock
|
|
|
|
import grpc
|
|
|
|
import mlagents_envs.rpc_communicator
|
|
from mlagents_envs.rpc_communicator import RpcCommunicator
|
|
from mlagents_envs.exception import (
|
|
UnityWorkerInUseException,
|
|
UnityTimeOutException,
|
|
UnityEnvironmentException,
|
|
)
|
|
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
|
|
|
|
|
|
def test_rpc_communicator_checks_port_on_create():
    """A second communicator on an already-used worker_id must be rejected.

    Both communicators are built with the default worker_id, so the second
    construction is expected to raise UnityWorkerInUseException.
    """
    first_comm = RpcCommunicator()
    try:
        with pytest.raises(UnityWorkerInUseException):
            second_comm = RpcCommunicator()
            # Only reached if the constructor unexpectedly succeeded: release
            # it before pytest.raises reports the missing exception.
            second_comm.close()
    finally:
        # Always release the first communicator, even when the assertion
        # above fails, so later tests can reuse the same worker_id.
        first_comm.close()
|
|
|
|
|
|
def test_rpc_communicator_close():
    """Closing a communicator frees its worker_id for a new one.

    Opens and closes a communicator with the default worker_id twice in a
    row; the second construction must not raise.
    """
    for _attempt in range(2):
        comm = RpcCommunicator()
        comm.close()
|
|
|
|
|
|
def test_rpc_communicator_create_multiple_workers():
    """Communicators with different worker_ids can coexist without error."""
    first_comm = RpcCommunicator()
    try:
        second_comm = RpcCommunicator(worker_id=1)
    finally:
        # If worker 1 fails to start, still release worker 0 so its port
        # does not stay occupied for the rest of the test session.
        first_comm.close()
    second_comm.close()
|
|
|
|
|
|
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_OK(mock_impl, mock_grpc_server):
    """initialize() completes when the (mocked) connection poll reports data.

    ``grpc.server`` and ``UnityToExternalServicerImplementation`` are patched
    out. Patch decorators apply bottom-up, so the servicer mock is passed
    first (``mock_impl``) and the ``grpc.server`` mock second.
    """
    comm = RpcCommunicator(timeout_wait=0.25)
    # unity_to_external is presumably the patched servicer instance, so its
    # parent_conn is a MagicMock whose poll() we can script.
    comm.unity_to_external.parent_conn.poll.return_value = True
    # Named to avoid shadowing the builtin ``input``.
    unity_input = UnityInputProto()
    comm.initialize(unity_input)
    comm.unity_to_external.parent_conn.poll.assert_called()
|
|
|
|
|
|
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_timeout(mock_impl, mock_grpc_server):
    """initialize() raises UnityTimeOutException when the poll never succeeds.

    The mocked connection's poll() always returns None, so initialize() is
    expected to give up after ``timeout_wait`` (0.25s) and raise.
    """
    comm = RpcCommunicator(timeout_wait=0.25)
    comm.unity_to_external.parent_conn.poll.return_value = None
    # Named to avoid shadowing the builtin ``input``.
    unity_input = UnityInputProto()
    with pytest.raises(UnityTimeOutException):
        comm.initialize(unity_input)
    comm.unity_to_external.parent_conn.poll.assert_called()
|
|
|
|
|
|
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_callback(mock_impl, mock_grpc_server):
    """An exception raised by ``poll_callback`` propagates out of initialize().

    The mocked poll() never reports data, so initialize() invokes the supplied
    callback while waiting; the callback's UnityEnvironmentException must
    surface to the caller.
    """

    def callback():
        raise UnityEnvironmentException

    comm = RpcCommunicator(timeout_wait=0.25)
    comm.unity_to_external.parent_conn.poll.return_value = None
    # Named to avoid shadowing the builtin ``input``.
    unity_input = UnityInputProto()
    # Expect the callback's exception, not a timeout.
    with pytest.raises(UnityEnvironmentException):
        comm.initialize(unity_input, poll_callback=callback)
    comm.unity_to_external.parent_conn.poll.assert_called()
|