Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过 35 个字符。
 
 
 
 
 

270 行
10 KiB

import logging
from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set
import cloudpickle
from mlagents.envs.environment import UnityEnvironment
from mlagents.envs.exception import UnityCommunicationException
from multiprocessing import Process, Pipe, Queue
from multiprocessing.connection import Connection
from queue import Empty as EmptyQueueException
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.env_manager import EnvManager, EnvironmentStep
from mlagents.envs.timers import (
TimerNode,
timed,
hierarchical_timer,
reset_timers,
get_timer_root,
)
from mlagents.envs.brain import AllBrainInfo, BrainParameters
from mlagents.envs.action_info import ActionInfo
logger = logging.getLogger("mlagents.envs")
class EnvironmentCommand(NamedTuple):
    """Message sent from the manager process to an environment worker.

    ``name`` selects the operation the worker should perform and ``payload``
    carries that operation's arguments (``None`` when there are none).
    """

    # Command identifier, e.g. "step", "reset" or "close".
    name: str
    # Optional command arguments.
    payload: Any = None
class EnvironmentResponse(NamedTuple):
    """Message sent from an environment worker back to the manager process.

    ``name`` echoes the kind of response, ``worker_id`` identifies the worker
    that produced it, and ``payload`` holds the result data.
    """

    # Response identifier, e.g. "step", "reset" or "env_close".
    name: str
    # Id of the worker that produced this response.
    worker_id: int
    # Result data for the response (may be None).
    payload: Any
class StepResponse(NamedTuple):
    """Payload of a "step" response produced by an environment worker.

    Pairs the brain info returned by the environment's ``step`` call with the
    worker's timer tree so the manager can merge it into its own timers.
    """

    # Result of ``env.step(...)`` in the worker.
    all_brain_info: AllBrainInfo
    # Root of the worker's timer hierarchy, or None when no timers are sent.
    timer_root: Optional[TimerNode]
class UnityEnvWorker:
    """Parent-side handle for a single environment subprocess.

    Bundles the child ``Process``, its worker id and the parent's end of the
    command pipe, along with the per-worker bookkeeping (last step, last
    actions, and whether a step is currently outstanding).
    """

    def __init__(self, process: Process, worker_id: int, conn: Connection):
        """Store the process handle, worker id and pipe connection.

        :param process: The subprocess running the environment.
        :param worker_id: Numeric id of this worker.
        :param conn: Parent end of the pipe used to talk to the subprocess.
        """
        self.process = process
        self.worker_id = worker_id
        self.conn = conn
        # Most recent step recorded for this worker; starts out empty.
        self.previous_step: EnvironmentStep = EnvironmentStep(None, {}, None)
        # Actions most recently sent to this worker, keyed by brain name.
        self.previous_all_action_info: Dict[str, ActionInfo] = {}
        # True while a "step" command is outstanding for this worker.
        self.waiting = False

    def send(self, name: str, payload: Any = None) -> None:
        """Send a command to the worker process.

        :param name: Command name.
        :param payload: Optional command arguments.
        :raises UnityCommunicationException: If the pipe is broken or closed.
        """
        try:
            cmd = EnvironmentCommand(name, payload)
            self.conn.send(cmd)
        except (BrokenPipeError, EOFError) as err:
            # Chain the low-level error so the root cause stays visible.
            raise UnityCommunicationException(
                "UnityEnvironment worker: send failed."
            ) from err

    def recv(self) -> EnvironmentResponse:
        """Block until the worker replies over the pipe and return the reply.

        :raises UnityCommunicationException: If the pipe is broken or closed.
        """
        try:
            response: EnvironmentResponse = self.conn.recv()
            return response
        except (BrokenPipeError, EOFError) as err:
            # Chain the low-level error so the root cause stays visible.
            raise UnityCommunicationException(
                "UnityEnvironment worker: recv failed."
            ) from err

    def close(self) -> None:
        """Ask the worker to shut down and wait for its process to exit.

        A failure to deliver the "close" command is only logged: the worker
        may already have gone away, and the process is joined regardless.
        """
        try:
            self.conn.send(EnvironmentCommand("close"))
        except (BrokenPipeError, EOFError):
            logger.debug(
                f"UnityEnvWorker {self.worker_id} got exception trying to close."
            )
        logger.debug(f"UnityEnvWorker {self.worker_id} joining process.")
        self.process.join()
def worker(
    parent_conn: Connection, step_queue: Queue, pickled_env_factory: str, worker_id: int
) -> None:
    """Entry point that runs inside each environment subprocess.

    Rebuilds the environment factory from its cloudpickle payload, creates the
    environment for ``worker_id`` and then serves commands read from
    ``parent_conn`` until a "close" command arrives or communication fails.
    Step results are published on ``step_queue``; all other replies are sent
    back over ``parent_conn``.

    :param parent_conn: This worker's end of the pipe to the manager.
    :param step_queue: Queue shared with the manager for step results.
    :param pickled_env_factory: cloudpickle-serialized environment factory.
        NOTE(review): annotated ``str``, but the caller passes the output of
        ``cloudpickle.dumps`` -- presumably bytes; confirm before relying on it.
    :param worker_id: Id passed to the factory and attached to every response.
    """
    make_env: Callable[[int], UnityEnvironment] = cloudpickle.loads(
        pickled_env_factory
    )
    env = make_env(worker_id)

    def _send_response(cmd_name, payload):
        # Reply to the manager over the pipe, tagged with this worker's id.
        parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))

    try:
        while True:
            cmd: EnvironmentCommand = parent_conn.recv()
            command = cmd.name
            if command == "close":
                break
            if command == "step":
                # Split the per-brain ActionInfo objects into the four
                # dictionaries that env.step() takes.
                per_brain = cmd.payload
                actions = {name: info.action for name, info in per_brain.items()}
                memories = {name: info.memory for name, info in per_brain.items()}
                texts = {name: info.text for name, info in per_brain.items()}
                values = {name: info.value for name, info in per_brain.items()}
                all_brain_info = env.step(actions, memories, texts, values)
                # The timers in this process are independent from all the processes and the "main" process
                # So after we send back the root timer, we can safely clear them.
                # Note that we could randomly return timers a fraction of the time if we wanted to reduce
                # the data transferred.
                # TODO get gauges from the workers and merge them in the main process too.
                step_response = StepResponse(all_brain_info, get_timer_root())
                step_queue.put(EnvironmentResponse("step", worker_id, step_response))
                reset_timers()
            elif command == "reset":
                reset_args = cmd.payload
                all_brain_info = env.reset(
                    reset_args[0], reset_args[1], reset_args[2]
                )
                _send_response("reset", all_brain_info)
            elif command == "external_brains":
                _send_response("external_brains", env.external_brains)
            elif command == "reset_parameters":
                _send_response("reset_parameters", env.reset_parameters)
    except (KeyboardInterrupt, UnityCommunicationException):
        print("UnityEnvironment worker: environment stopping.")
        # Tell the manager that this environment is gone.
        step_queue.put(EnvironmentResponse("env_close", worker_id, None))
    finally:
        # If this worker has put an item in the step queue that hasn't been processed by the EnvManager, the process
        # will hang until the item is processed. We avoid this behavior by using Queue.cancel_join_thread()
        # See https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue.cancel_join_thread for
        # more info.
        logger.debug(f"Worker {worker_id} closing.")
        step_queue.cancel_join_thread()
        step_queue.close()
        env.close()
        logger.debug(f"Worker {worker_id} done.")
class SubprocessEnvManager(EnvManager):
    """Environment manager that runs each environment in its own subprocess.

    Every worker has a private pipe for commands and direct replies, while
    completed steps from all workers arrive on one shared ``step_queue``.
    ``step`` returns once at least one worker has delivered a result; workers
    that are still busy stay flagged as waiting and are collected by a later
    call.
    """

    def __init__(
        self, env_factory: Callable[[int], BaseUnityEnvironment], n_env: int = 1
    ):
        """Start ``n_env`` worker subprocesses.

        :param env_factory: Callable that builds an environment from a worker id.
        :param n_env: Number of worker subprocesses to start.
        """
        super().__init__()
        self.env_workers: List[UnityEnvWorker] = []
        self.step_queue: Queue = Queue()
        for worker_idx in range(n_env):
            self.env_workers.append(
                self.create_worker(worker_idx, self.step_queue, env_factory)
            )

    @staticmethod
    def create_worker(
        worker_id: int,
        step_queue: Queue,
        env_factory: Callable[[int], BaseUnityEnvironment],
    ) -> UnityEnvWorker:
        """Start one worker subprocess and return the handle wrapping it.

        :param worker_id: Id assigned to the new worker.
        :param step_queue: Queue on which the worker publishes step results.
        :param env_factory: Callable that builds an environment from a worker id.
        :return: Handle holding the started process and the parent pipe end.
        """
        parent_conn, child_conn = Pipe()
        # Need to use cloudpickle for the env factory function since function objects aren't picklable
        # on Windows as of Python 3.6.
        pickled_env_factory = cloudpickle.dumps(env_factory)
        child_process = Process(
            target=worker, args=(child_conn, step_queue, pickled_env_factory, worker_id)
        )
        child_process.start()
        return UnityEnvWorker(child_process, worker_id, parent_conn)

    def _queue_steps(self) -> None:
        """Send a "step" command to every worker that is not already waiting."""
        for env_worker in self.env_workers:
            if not env_worker.waiting:
                env_action_info = self._take_step(env_worker.previous_step)
                env_worker.previous_all_action_info = env_action_info
                env_worker.send("step", env_action_info)
                env_worker.waiting = True

    def step(self) -> List[EnvironmentStep]:
        """Advance the environments and return the steps that have completed.

        :return: One ``EnvironmentStep`` per worker that finished a step
            (always at least one).
        :raises UnityCommunicationException: If a worker reports that its
            environment has closed.
        """
        # Queue steps for any workers which aren't in the "waiting" state.
        self._queue_steps()

        worker_steps: List[EnvironmentResponse] = []
        step_workers: Set[int] = set()
        # Poll the step queue for completed steps from environment workers until we retrieve
        # 1 or more, which we will then return as StepInfos
        while not worker_steps:
            try:
                # Drain everything currently available; the Empty exception
                # ends this inner loop.
                while True:
                    step = self.step_queue.get_nowait()
                    if step.name == "env_close":
                        raise UnityCommunicationException(
                            "At least one of the environments has closed."
                        )
                    self.env_workers[step.worker_id].waiting = False
                    if step.worker_id not in step_workers:
                        worker_steps.append(step)
                        step_workers.add(step.worker_id)
            except EmptyQueueException:
                pass

        step_infos = self._postprocess_steps(worker_steps)
        return step_infos

    def reset(
        self,
        config: Optional[Dict] = None,
        train_mode: bool = True,
        custom_reset_parameters: Any = None,
    ) -> List[EnvironmentStep]:
        """Reset every environment and return the initial step of each worker.

        Steps still in flight are drained (and discarded) first so that no
        worker is mid-step when the reset command is sent.

        :param config: Forwarded as the first argument of the environment reset.
        :param train_mode: Forwarded as the second argument.
        :param custom_reset_parameters: Forwarded as the third argument.
        :return: The new ``previous_step`` of every worker, in worker order.
        :raises UnityCommunicationException: If a worker reports that its
            environment has closed, or if communication with a worker fails.
        """
        while any(ew.waiting for ew in self.env_workers):
            try:
                step = self.step_queue.get_nowait()
            except EmptyQueueException:
                # Nothing available yet; keep polling until the outstanding
                # steps have arrived.
                continue
            if step.name == "env_close":
                # A closed worker cannot be reset; fail the same way step() does
                # rather than sending "reset" to a dead process.
                raise UnityCommunicationException(
                    "At least one of the environments has closed."
                )
            self.env_workers[step.worker_id].waiting = False
        # First enqueue reset commands for all workers so that they reset in parallel
        for ew in self.env_workers:
            ew.send("reset", (config, train_mode, custom_reset_parameters))
        # Next (synchronously) collect the reset observations from each worker in sequence
        for ew in self.env_workers:
            ew.previous_step = EnvironmentStep(None, ew.recv().payload, None)
        return [ew.previous_step for ew in self.env_workers]

    @property
    def external_brains(self) -> Dict[str, BrainParameters]:
        """The ``external_brains`` value reported by the first worker's environment."""
        self.env_workers[0].send("external_brains")
        return self.env_workers[0].recv().payload

    @property
    def reset_parameters(self) -> Dict[str, float]:
        """The ``reset_parameters`` value reported by the first worker's environment."""
        self.env_workers[0].send("reset_parameters")
        return self.env_workers[0].recv().payload

    def close(self) -> None:
        """Close the step queue, then shut down and join every worker in turn."""
        logger.debug("SubprocessEnvManager closing.")
        self.step_queue.close()
        self.step_queue.join_thread()
        for env_worker in self.env_workers:
            env_worker.close()

    def _postprocess_steps(
        self, env_steps: List[EnvironmentResponse]
    ) -> List[EnvironmentStep]:
        """Turn raw step responses into ``EnvironmentStep`` objects.

        Each new step links the worker's previous brain info, the brain info
        just received and the actions that produced it. The step is stored as
        the worker's ``previous_step``. Timer trees sent by the workers are
        merged into the main process timers.

        :param env_steps: "step" responses collected from the step queue.
        :return: One ``EnvironmentStep`` per response, in the same order.
        """
        step_infos = []
        timer_nodes = []
        for step in env_steps:
            payload: StepResponse = step.payload
            env_worker = self.env_workers[step.worker_id]
            new_step = EnvironmentStep(
                env_worker.previous_step.current_all_brain_info,
                payload.all_brain_info,
                env_worker.previous_all_action_info,
            )
            step_infos.append(new_step)
            env_worker.previous_step = new_step

            if payload.timer_root:
                timer_nodes.append(payload.timer_root)

        if timer_nodes:
            with hierarchical_timer("workers") as main_timer_node:
                for worker_timer_node in timer_nodes:
                    main_timer_node.merge(
                        worker_timer_node, root_name="worker_root", is_parallel=True
                    )

        return step_infos

    @timed
    def _take_step(self, last_step: EnvironmentStep) -> Dict[str, ActionInfo]:
        """Compute the next action for every brain that has a policy.

        :param last_step: Step whose current brain info is fed to the policies.
        :return: Mapping of brain name to the ``ActionInfo`` chosen by its
            policy. Brains with no entry in ``self.policies`` are skipped
            (``self.policies`` is not defined in this class; presumably it
            comes from ``EnvManager``).
        """
        all_action_info: Dict[str, ActionInfo] = {}
        for brain_name, brain_info in last_step.current_all_brain_info.items():
            if brain_name in self.policies:
                all_action_info[brain_name] = self.policies[brain_name].get_action(
                    brain_info
                )
        return all_action_info