Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

432 行
17 KiB

from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set
import cloudpickle
import enum
import time
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.exception import (
UnityCommunicationException,
UnityTimeOutException,
UnityEnvironmentException,
UnityCommunicatorStoppedException,
)
from multiprocessing import Process, Pipe, Queue
from multiprocessing.connection import Connection
from queue import Empty as EmptyQueueException
from mlagents_envs.base_env import BaseEnv, BehaviorName, BehaviorSpec
from mlagents_envs import logging_util
from mlagents.trainers.env_manager import EnvManager, EnvironmentStep, AllStepResult
from mlagents.trainers.settings import TrainerSettings
from mlagents_envs.timers import (
TimerNode,
timed,
hierarchical_timer,
reset_timers,
get_timer_root,
)
from mlagents.trainers.settings import ParameterRandomizationSettings, RunOptions
from mlagents.trainers.action_info import ActionInfo
from mlagents_envs.side_channel.environment_parameters_channel import (
EnvironmentParametersChannel,
)
from mlagents_envs.side_channel.engine_configuration_channel import (
EngineConfigurationChannel,
EngineConfig,
)
from mlagents_envs.side_channel.stats_side_channel import (
EnvironmentStats,
StatsSideChannel,
)
from mlagents.training_analytics_side_channel import TrainingAnalyticsSideChannel
from mlagents_envs.side_channel.side_channel import SideChannel
logger = logging_util.get_logger(__name__)
WORKER_SHUTDOWN_TIMEOUT_S = 10
class EnvironmentCommand(enum.Enum):
STEP = 1
BEHAVIOR_SPECS = 2
ENVIRONMENT_PARAMETERS = 3
RESET = 4
CLOSE = 5
ENV_EXITED = 6
CLOSED = 7
TRAINING_STARTED = 8
class EnvironmentRequest(NamedTuple):
cmd: EnvironmentCommand
payload: Any = None
class EnvironmentResponse(NamedTuple):
cmd: EnvironmentCommand
worker_id: int
payload: Any
class StepResponse(NamedTuple):
all_step_result: AllStepResult
timer_root: Optional[TimerNode]
environment_stats: EnvironmentStats
class UnityEnvWorker:
def __init__(self, process: Process, worker_id: int, conn: Connection):
self.process = process
self.worker_id = worker_id
self.conn = conn
self.previous_step: EnvironmentStep = EnvironmentStep.empty(worker_id)
self.previous_all_action_info: Dict[str, ActionInfo] = {}
self.waiting = False
self.closed = False
def send(self, cmd: EnvironmentCommand, payload: Any = None) -> None:
try:
req = EnvironmentRequest(cmd, payload)
self.conn.send(req)
except (BrokenPipeError, EOFError):
raise UnityCommunicationException("UnityEnvironment worker: send failed.")
def recv(self) -> EnvironmentResponse:
try:
response: EnvironmentResponse = self.conn.recv()
if response.cmd == EnvironmentCommand.ENV_EXITED:
env_exception: Exception = response.payload
raise env_exception
return response
except (BrokenPipeError, EOFError):
raise UnityCommunicationException("UnityEnvironment worker: recv failed.")
def request_close(self):
try:
self.conn.send(EnvironmentRequest(EnvironmentCommand.CLOSE))
except (BrokenPipeError, EOFError):
logger.debug(
f"UnityEnvWorker {self.worker_id} got exception trying to close."
)
pass
def worker(
parent_conn: Connection,
step_queue: Queue,
pickled_env_factory: str,
worker_id: int,
run_options: RunOptions,
log_level: int = logging_util.INFO,
) -> None:
env_factory: Callable[
[int, List[SideChannel]], UnityEnvironment
] = cloudpickle.loads(pickled_env_factory)
env_parameters = EnvironmentParametersChannel()
engine_config = EngineConfig(
width=run_options.engine_settings.width,
height=run_options.engine_settings.height,
quality_level=run_options.engine_settings.quality_level,
time_scale=run_options.engine_settings.time_scale,
target_frame_rate=run_options.engine_settings.target_frame_rate,
capture_frame_rate=run_options.engine_settings.capture_frame_rate,
)
engine_configuration_channel = EngineConfigurationChannel()
engine_configuration_channel.set_configuration(engine_config)
stats_channel = StatsSideChannel()
training_analytics_channel: Optional[TrainingAnalyticsSideChannel] = None
if worker_id == 0:
training_analytics_channel = TrainingAnalyticsSideChannel()
env: UnityEnvironment = None
# Set log level. On some platforms, the logger isn't common with the
# main process, so we need to set it again.
logging_util.set_log_level(log_level)
def _send_response(cmd_name: EnvironmentCommand, payload: Any) -> None:
parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))
def _generate_all_results() -> AllStepResult:
all_step_result: AllStepResult = {}
for brain_name in env.behavior_specs:
all_step_result[brain_name] = env.get_steps(brain_name)
return all_step_result
try:
side_channels = [env_parameters, engine_configuration_channel, stats_channel]
if training_analytics_channel is not None:
side_channels.append(training_analytics_channel)
env = env_factory(worker_id, side_channels)
if (
not env.academy_capabilities
or not env.academy_capabilities.trainingAnalytics
):
# Make sure we don't try to send training analytics if the environment doesn't know how to process
# them. This wouldn't be catastrophic, but would result in unknown SideChannel UUIDs being used.
training_analytics_channel = None
if training_analytics_channel:
training_analytics_channel.environment_initialized(run_options)
while True:
req: EnvironmentRequest = parent_conn.recv()
if req.cmd == EnvironmentCommand.STEP:
all_action_info = req.payload
for brain_name, action_info in all_action_info.items():
if len(action_info.agent_ids) > 0:
env.set_actions(brain_name, action_info.env_action)
env.step()
all_step_result = _generate_all_results()
# The timers in this process are independent from all the processes and the "main" process
# So after we send back the root timer, we can safely clear them.
# Note that we could randomly return timers a fraction of the time if we wanted to reduce
# the data transferred.
# TODO get gauges from the workers and merge them in the main process too.
env_stats = stats_channel.get_and_reset_stats()
step_response = StepResponse(
all_step_result, get_timer_root(), env_stats
)
step_queue.put(
EnvironmentResponse(
EnvironmentCommand.STEP, worker_id, step_response
)
)
reset_timers()
elif req.cmd == EnvironmentCommand.BEHAVIOR_SPECS:
_send_response(EnvironmentCommand.BEHAVIOR_SPECS, env.behavior_specs)
elif req.cmd == EnvironmentCommand.ENVIRONMENT_PARAMETERS:
for k, v in req.payload.items():
if isinstance(v, ParameterRandomizationSettings):
v.apply(k, env_parameters)
elif req.cmd == EnvironmentCommand.TRAINING_STARTED:
behavior_name, trainer_config = req.payload
if training_analytics_channel:
training_analytics_channel.training_started(
behavior_name, trainer_config
)
elif req.cmd == EnvironmentCommand.RESET:
env.reset()
all_step_result = _generate_all_results()
_send_response(EnvironmentCommand.RESET, all_step_result)
elif req.cmd == EnvironmentCommand.CLOSE:
break
except (
KeyboardInterrupt,
UnityCommunicationException,
UnityTimeOutException,
UnityEnvironmentException,
UnityCommunicatorStoppedException,
) as ex:
logger.debug(f"UnityEnvironment worker {worker_id}: environment stopping.")
step_queue.put(
EnvironmentResponse(EnvironmentCommand.ENV_EXITED, worker_id, ex)
)
_send_response(EnvironmentCommand.ENV_EXITED, ex)
except Exception as ex:
logger.exception(
f"UnityEnvironment worker {worker_id}: environment raised an unexpected exception."
)
step_queue.put(
EnvironmentResponse(EnvironmentCommand.ENV_EXITED, worker_id, ex)
)
_send_response(EnvironmentCommand.ENV_EXITED, ex)
finally:
logger.debug(f"UnityEnvironment worker {worker_id} closing.")
if env is not None:
env.close()
logger.debug(f"UnityEnvironment worker {worker_id} done.")
parent_conn.close()
step_queue.put(EnvironmentResponse(EnvironmentCommand.CLOSED, worker_id, None))
step_queue.close()
class SubprocessEnvManager(EnvManager):
def __init__(
self,
env_factory: Callable[[int, List[SideChannel]], BaseEnv],
run_options: RunOptions,
n_env: int = 1,
):
super().__init__()
self.env_workers: List[UnityEnvWorker] = []
self.step_queue: Queue = Queue()
self.workers_alive = 0
for worker_idx in range(n_env):
self.env_workers.append(
self.create_worker(
worker_idx, self.step_queue, env_factory, run_options
)
)
self.workers_alive += 1
@staticmethod
def create_worker(
worker_id: int,
step_queue: Queue,
env_factory: Callable[[int, List[SideChannel]], BaseEnv],
run_options: RunOptions,
) -> UnityEnvWorker:
parent_conn, child_conn = Pipe()
# Need to use cloudpickle for the env factory function since function objects aren't picklable
# on Windows as of Python 3.6.
pickled_env_factory = cloudpickle.dumps(env_factory)
child_process = Process(
target=worker,
args=(
child_conn,
step_queue,
pickled_env_factory,
worker_id,
run_options,
logger.level,
),
)
child_process.start()
return UnityEnvWorker(child_process, worker_id, parent_conn)
def _queue_steps(self) -> None:
for env_worker in self.env_workers:
if not env_worker.waiting:
env_action_info = self._take_step(env_worker.previous_step)
env_worker.previous_all_action_info = env_action_info
env_worker.send(EnvironmentCommand.STEP, env_action_info)
env_worker.waiting = True
def _step(self) -> List[EnvironmentStep]:
# Queue steps for any workers which aren't in the "waiting" state.
self._queue_steps()
worker_steps: List[EnvironmentResponse] = []
step_workers: Set[int] = set()
# Poll the step queue for completed steps from environment workers until we retrieve
# 1 or more, which we will then return as StepInfos
while len(worker_steps) < 1:
try:
while True:
step: EnvironmentResponse = self.step_queue.get_nowait()
if step.cmd == EnvironmentCommand.ENV_EXITED:
env_exception: Exception = step.payload
raise env_exception
self.env_workers[step.worker_id].waiting = False
if step.worker_id not in step_workers:
worker_steps.append(step)
step_workers.add(step.worker_id)
except EmptyQueueException:
pass
step_infos = self._postprocess_steps(worker_steps)
return step_infos
def _reset_env(self, config: Optional[Dict] = None) -> List[EnvironmentStep]:
while any(ew.waiting for ew in self.env_workers):
if not self.step_queue.empty():
step = self.step_queue.get_nowait()
self.env_workers[step.worker_id].waiting = False
# Send config to environment
self.set_env_parameters(config)
# First enqueue reset commands for all workers so that they reset in parallel
for ew in self.env_workers:
ew.send(EnvironmentCommand.RESET, config)
# Next (synchronously) collect the reset observations from each worker in sequence
for ew in self.env_workers:
ew.previous_step = EnvironmentStep(ew.recv().payload, ew.worker_id, {}, {})
return list(map(lambda ew: ew.previous_step, self.env_workers))
def set_env_parameters(self, config: Dict = None) -> None:
"""
Sends environment parameter settings to C# via the
EnvironmentParametersSidehannel for each worker.
:param config: Dict of environment parameter keys and values
"""
for ew in self.env_workers:
ew.send(EnvironmentCommand.ENVIRONMENT_PARAMETERS, config)
def on_training_started(
self, behavior_name: str, trainer_settings: TrainerSettings
) -> None:
"""
Handle traing starting for a new behavior type. Generally nothing is necessary here.
:param behavior_name:
:param trainer_settings:
:return:
"""
for ew in self.env_workers:
ew.send(
EnvironmentCommand.TRAINING_STARTED, (behavior_name, trainer_settings)
)
@property
def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]:
result: Dict[BehaviorName, BehaviorSpec] = {}
for worker in self.env_workers:
worker.send(EnvironmentCommand.BEHAVIOR_SPECS)
result.update(worker.recv().payload)
return result
def close(self) -> None:
logger.debug("SubprocessEnvManager closing.")
for env_worker in self.env_workers:
env_worker.request_close()
# Pull messages out of the queue until every worker has CLOSED or we time out.
deadline = time.time() + WORKER_SHUTDOWN_TIMEOUT_S
while self.workers_alive > 0 and time.time() < deadline:
try:
step: EnvironmentResponse = self.step_queue.get_nowait()
env_worker = self.env_workers[step.worker_id]
if step.cmd == EnvironmentCommand.CLOSED and not env_worker.closed:
env_worker.closed = True
self.workers_alive -= 1
# Discard all other messages.
except EmptyQueueException:
pass
self.step_queue.close()
# Sanity check to kill zombie workers and report an issue if they occur.
if self.workers_alive > 0:
logger.error("SubprocessEnvManager had workers that didn't signal shutdown")
for env_worker in self.env_workers:
if not env_worker.closed and env_worker.process.is_alive():
env_worker.process.terminate()
logger.error(
"A SubprocessEnvManager worker did not shut down correctly so it was forcefully terminated."
)
self.step_queue.join_thread()
def _postprocess_steps(
self, env_steps: List[EnvironmentResponse]
) -> List[EnvironmentStep]:
step_infos = []
timer_nodes = []
for step in env_steps:
payload: StepResponse = step.payload
env_worker = self.env_workers[step.worker_id]
new_step = EnvironmentStep(
payload.all_step_result,
step.worker_id,
env_worker.previous_all_action_info,
payload.environment_stats,
)
step_infos.append(new_step)
env_worker.previous_step = new_step
if payload.timer_root:
timer_nodes.append(payload.timer_root)
if timer_nodes:
with hierarchical_timer("workers") as main_timer_node:
for worker_timer_node in timer_nodes:
main_timer_node.merge(
worker_timer_node, root_name="worker_root", is_parallel=True
)
return step_infos
@timed
def _take_step(self, last_step: EnvironmentStep) -> Dict[BehaviorName, ActionInfo]:
all_action_info: Dict[str, ActionInfo] = {}
for brain_name, step_tuple in last_step.current_all_step_result.items():
if brain_name in self.policies:
all_action_info[brain_name] = self.policies[brain_name].get_action(
step_tuple[0], last_step.worker_id
)
return all_action_info