浏览代码

[MLA-1712] Make UnityEnvironment fail fast if the env crashes (#4880)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
2af86534
共有 11 个文件被更改,包括 167 次插入41 次删除
  1. 3
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
  2. 2
      Project/ProjectSettings/UnityConnectSettings.asset
  3. 8
      com.unity.ml-agents/CHANGELOG.md
  4. 6
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  5. 17
      ml-agents-envs/mlagents_envs/communicator.py
  6. 8
      ml-agents-envs/mlagents_envs/env_utils.py
  7. 49
      ml-agents-envs/mlagents_envs/environment.py
  8. 12
      ml-agents-envs/mlagents_envs/mock_communicator.py
  9. 47
      ml-agents-envs/mlagents_envs/rpc_communicator.py
  10. 54
      ml-agents-envs/mlagents_envs/tests/test_rpc_communicator.py
  11. 2
      ml-agents/mlagents/trainers/subprocess_env_manager.py

3
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using UnityEngine.Rendering;
using UnityEngine.Serialization;
public class GridAgent : Agent

void WaitTimeInference()
{
if (renderCamera != null)
if (renderCamera != null && SystemInfo.graphicsDeviceType != GraphicsDeviceType.Null)
{
renderCamera.Render();
}

2
Project/ProjectSettings/UnityConnectSettings.asset


UnityConnectSettings:
m_ObjectHideFlags: 0
serializedVersion: 1
m_Enabled: 1
m_Enabled: 0
m_TestMode: 0
m_EventOldUrl: https://api.uca.cloud.unity3d.com/v1/events
m_EventUrl: https://cdp.cloud.unity3d.com/v1/events

8
com.unity.ml-agents/CHANGELOG.md


Updated the Basic example and the Match3 Example to use Actuators.
Changed the namespace and file names of classes in com.unity.ml-agents.extensions. (#4849)
- CameraSensor now logs an error if the GraphicsDevice is null. (#4880)
- Fixed a bug that would cause `UnityEnvironment` to wait the full timeout
period and report a misleading error message if the executable crashed
without closing the connection. It now periodically checks the process status
while waiting for a connection, and raises a better error message if it crashes. (#4880)
- Passing a `-logfile` option in the `--env-args` option to `mlagents-learn` is
no longer overwritten. (#4880)
## [1.7.2-preview] - 2020-12-22

6
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


using UnityEngine;
using UnityEngine.Rendering;
namespace Unity.MLAgents.Sensors
{

/// <returns name="texture2D">Texture2D to render to.</returns>
public static Texture2D ObservationToTexture(Camera obsCamera, int width, int height)
{
if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null)
{
Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render.");
}
var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false);
var oldRec = obsCamera.rect;
obsCamera.rect = new Rect(0f, 0f, 1f, 1f);

17
ml-agents-envs/mlagents_envs/communicator.py


from typing import Optional
from typing import Callable, Optional
# Function to call while waiting for a connection timeout.
# This should raise an exception if it needs to break from waiting for the timeout.
PollCallback = Callable[[], None]
class Communicator:

:int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
"""
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
def initialize(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> UnityOutputProto:
:param poll_callback: Optional callback to be used while polling the connection.
def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]:
def exchange(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> Optional[UnityOutputProto]:
:param poll_callback: Optional callback to be used while polling the connection.
:return: The UnityOutputs generated by the Environment
"""

8
ml-agents-envs/mlagents_envs/env_utils.py


from mlagents_envs.exception import UnityEnvironmentException
logger = get_logger(__name__)
def get_platform():
"""
returns the platform of the operating system : linux, darwin or win32

.replace(".x86", "")
)
true_filename = os.path.basename(os.path.normpath(env_path))
get_logger(__name__).debug(f"The true file name is {true_filename}")
logger.debug(f"The true file name is {true_filename}")
if not (glob.glob(env_path) or glob.glob(env_path + ".*")):
return None

f"Couldn't launch the {file_name} environment. Provided filename does not match any environments."
)
else:
get_logger(__name__).debug(f"This is the launch string {launch_string}")
logger.debug(f"The launch string is {launch_string}")
logger.debug(f"Running with args {args}")
# Launch Unity environment
subprocess_args = [launch_string] + args
try:

49
ml-agents-envs/mlagents_envs/environment.py


# If true, this means the environment was successfully loaded
self._loaded = False
# The process that is started. If None, no process was started
self._proc1 = None
self._process: Optional[subprocess.Popen] = None
self._timeout_wait: int = timeout_wait
self._communicator = self._get_communicator(worker_id, base_port, timeout_wait)
self._worker_id = worker_id

)
if file_name is not None:
try:
self._proc1 = env_utils.launch_executable(
self._process = env_utils.launch_executable(
file_name, self._executable_args()
)
except UnityEnvironmentException:

if self._no_graphics:
args += ["-nographics", "-batchmode"]
args += [UnityEnvironment._PORT_COMMAND_LINE_ARG, str(self._port)]
if self._log_folder:
# If the logfile arg isn't already set in the env args,
# try to set it to an output directory
logfile_set = "-logfile" in (arg.lower() for arg in self._additional_args)
if self._log_folder and not logfile_set:
log_file_path = os.path.join(
self._log_folder, f"Player-{self._worker_id}.log"
)

def reset(self) -> None:
if self._loaded:
outputs = self._communicator.exchange(self._generate_reset_input())
outputs = self._communicator.exchange(
self._generate_reset_input(), self._poll_process
)
if outputs is None:
raise UnityCommunicatorStoppedException("Communicator has exited.")
self._update_behavior_specs(outputs)

].action_spec.empty_action(n_agents)
step_input = self._generate_step_input(self._env_actions)
with hierarchical_timer("communicator.exchange"):
outputs = self._communicator.exchange(step_input)
outputs = self._communicator.exchange(step_input, self._poll_process)
if outputs is None:
raise UnityCommunicatorStoppedException("Communicator has exited.")
self._update_behavior_specs(outputs)

self._assert_behavior_exists(behavior_name)
return self._env_state[behavior_name]
def _poll_process(self) -> None:
"""
Check the status of the subprocess. If it has exited, raise a UnityEnvironmentException
:return: None
"""
if not self._process:
return
poll_res = self._process.poll()
if poll_res is not None:
exc_msg = self._returncode_to_env_message(self._process.returncode)
raise UnityEnvironmentException(exc_msg)
def close(self):
"""
Sends a shutdown signal to the unity environment, and closes the socket connection.

timeout = self._timeout_wait
self._loaded = False
self._communicator.close()
if self._proc1 is not None:
if self._process is not None:
self._proc1.wait(timeout=timeout)
signal_name = self._returncode_to_signal_name(self._proc1.returncode)
signal_name = f" ({signal_name})" if signal_name else ""
return_info = f"Environment shut down with return code {self._proc1.returncode}{signal_name}."
logger.info(return_info)
self._process.wait(timeout=timeout)
logger.info(self._returncode_to_env_message(self._process.returncode))
self._proc1.kill()
self._process.kill()
self._proc1 = None
self._process = None
@timed
def _generate_step_input(

) -> UnityOutputProto:
inputs = UnityInputProto()
inputs.rl_initialization_input.CopyFrom(init_parameters)
return self._communicator.initialize(inputs)
return self._communicator.initialize(inputs, self._poll_process)
@staticmethod
def _wrap_unity_input(rl_input: UnityRLInputProto) -> UnityInputProto:

except Exception:
# Should generally be a ValueError, but catch everything just in case.
return None
@staticmethod
def _returncode_to_env_message(returncode: int) -> str:
signal_name = UnityEnvironment._returncode_to_signal_name(returncode)
signal_name = f" ({signal_name})" if signal_name else ""
return f"Environment shut down with return code {returncode}{signal_name}."

12
ml-agents-envs/mlagents_envs/mock_communicator.py


from .communicator import Communicator
from typing import Optional
from .communicator import Communicator, PollCallback
from .environment import UnityEnvironment
from mlagents_envs.communicator_objects.unity_rl_output_pb2 import UnityRLOutputProto
from mlagents_envs.communicator_objects.brain_parameters_pb2 import (

self.brain_name = brain_name
self.vec_obs_size = vec_obs_size
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
def initialize(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> UnityOutputProto:
if self.is_discrete:
action_spec = ActionSpecProto(
num_discrete_actions=2, discrete_branch_sizes=[3, 2]

)
return dict_agent_info
def exchange(self, inputs: UnityInputProto) -> UnityOutputProto:
def exchange(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> UnityOutputProto:
result = UnityRLOutputProto(agentInfos=self._get_agent_infos())
return UnityOutputProto(rl_output=result)

47
ml-agents-envs/mlagents_envs/rpc_communicator.py


import grpc
from typing import Optional
from multiprocessing import Pipe
from multiprocessing import Pipe
import time
from .communicator import Communicator
from .communicator import Communicator, PollCallback
from mlagents_envs.communicator_objects.unity_to_external_pb2_grpc import (
UnityToExternalProtoServicer,
add_UnityToExternalProtoServicer_to_server,

finally:
s.close()
def poll_for_timeout(self):
def poll_for_timeout(self, poll_callback: Optional[PollCallback] = None) -> None:
Additionally, a callback can be passed to periodically check the state of the environment.
This is used to detect the case when the environment dies without cleaning up the connection,
so that we can stop sooner and raise a more appropriate error.
if not self.unity_to_external.parent_conn.poll(self.timeout_wait):
raise UnityTimeOutException(
"The Unity environment took too long to respond. Make sure that :\n"
"\t The environment does not need user interaction to launch\n"
'\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n'
"\t The environment and the Python interface have compatible versions."
)
deadline = time.monotonic() + self.timeout_wait
callback_timeout_wait = self.timeout_wait // 10
while time.monotonic() < deadline:
if self.unity_to_external.parent_conn.poll(callback_timeout_wait):
# Got an acknowledgment from the connection
return
if poll_callback:
# Fire the callback - if it detects something wrong, it should raise an exception.
poll_callback()
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:
self.poll_for_timeout()
# Got this far without reading any data from the connection, so it must be dead.
raise UnityTimeOutException(
"The Unity environment took too long to respond. Make sure that :\n"
"\t The environment does not need user interaction to launch\n"
'\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n'
"\t The environment and the Python interface have compatible versions."
)
def initialize(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> UnityOutputProto:
self.poll_for_timeout(poll_callback)
aca_param = self.unity_to_external.parent_conn.recv().unity_output
message = UnityMessageProto()
message.header.status = 200

return aca_param
def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]:
def exchange(
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
) -> Optional[UnityOutputProto]:
self.poll_for_timeout()
self.poll_for_timeout(poll_callback)
output = self.unity_to_external.parent_conn.recv()
if output.header.status != 200:
return None

54
ml-agents-envs/mlagents_envs/tests/test_rpc_communicator.py


import pytest
from unittest import mock
import grpc
import mlagents_envs.rpc_communicator
from mlagents_envs.exception import UnityWorkerInUseException
from mlagents_envs.exception import (
UnityWorkerInUseException,
UnityTimeOutException,
UnityEnvironmentException,
)
from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto
def test_rpc_communicator_checks_port_on_create():

second_comm = RpcCommunicator(worker_id=1)
first_comm.close()
second_comm.close()
@mock.patch.object(grpc, "server")
@mock.patch.object(
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_OK(mock_impl, mock_grpc_server):
comm = RpcCommunicator(timeout_wait=0.25)
comm.unity_to_external.parent_conn.poll.return_value = True
input = UnityInputProto()
comm.initialize(input)
comm.unity_to_external.parent_conn.poll.assert_called()
@mock.patch.object(grpc, "server")
@mock.patch.object(
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_timeout(mock_impl, mock_grpc_server):
comm = RpcCommunicator(timeout_wait=0.25)
comm.unity_to_external.parent_conn.poll.return_value = None
input = UnityInputProto()
# Expect a timeout
with pytest.raises(UnityTimeOutException):
comm.initialize(input)
comm.unity_to_external.parent_conn.poll.assert_called()
@mock.patch.object(grpc, "server")
@mock.patch.object(
mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_callback(mock_impl, mock_grpc_server):
def callback():
raise UnityEnvironmentException
comm = RpcCommunicator(timeout_wait=0.25)
comm.unity_to_external.parent_conn.poll.return_value = None
input = UnityInputProto()
# Expect a timeout
with pytest.raises(UnityEnvironmentException):
comm.initialize(input, poll_callback=callback)
comm.unity_to_external.parent_conn.poll.assert_called()

2
ml-agents/mlagents/trainers/subprocess_env_manager.py


)
_send_response(EnvironmentCommand.ENV_EXITED, ex)
except Exception as ex:
logger.error(
logger.exception(
f"UnityEnvironment worker {worker_id}: environment raised an unexpected exception."
)
step_queue.put(

正在加载...
取消
保存