import logging
from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set
import cloudpickle

from mlagents.envs.environment import UnityEnvironment
from mlagents.envs.exception import UnityCommunicationException, UnityTimeOutException
from multiprocessing import Process, Pipe, Queue
from multiprocessing.connection import Connection
from queue import Empty as EmptyQueueException
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.env_manager import EnvManager, EnvironmentStep
from mlagents.envs.timers import (
    TimerNode,
    timed,
    hierarchical_timer,
    reset_timers,
    get_timer_root,
)
from mlagents.envs.brain import AllBrainInfo, BrainParameters
from mlagents.envs.action_info import ActionInfo
from mlagents.envs.side_channel.float_properties_channel import FloatPropertiesChannel
from mlagents.envs.side_channel.engine_configuration_channel import (
    EngineConfigurationChannel,
    EngineConfig,
)
from mlagents.envs.side_channel.side_channel import SideChannel

logger = logging.getLogger("mlagents.envs")


class EnvironmentCommand(NamedTuple):
    name: str
    payload: Any = None


class EnvironmentResponse(NamedTuple):
    name: str
    worker_id: int
    payload: Any


class StepResponse(NamedTuple):
    all_brain_info: AllBrainInfo
    timer_root: Optional[TimerNode]


class UnityEnvWorker:
    def __init__(self, process: Process, worker_id: int, conn: Connection):
        self.process = process
        self.worker_id = worker_id
        self.conn = conn
        self.previous_step: EnvironmentStep = EnvironmentStep(None, {}, None)
        self.previous_all_action_info: Dict[str, ActionInfo] = {}
        self.waiting = False

    def send(self, name: str, payload: Any = None) -> None:
        try:
            cmd = EnvironmentCommand(name, payload)
            self.conn.send(cmd)
        except (BrokenPipeError, EOFError):
            raise UnityCommunicationException("UnityEnvironment worker: send failed.")

    def recv(self) -> EnvironmentResponse:
        try:
            response: EnvironmentResponse = self.conn.recv()
            return response
        except (BrokenPipeError, EOFError):
            raise UnityCommunicationException("UnityEnvironment worker: recv failed.")

    def close(self) -> None:
        try:
            self.conn.send(EnvironmentCommand("close"))
        except (BrokenPipeError, EOFError):
            logger.debug(
                f"UnityEnvWorker {self.worker_id} got exception trying to close."
            )
        logger.debug(f"UnityEnvWorker {self.worker_id} joining process.")
        self.process.join()


def worker(
    parent_conn: Connection,
    step_queue: Queue,
    pickled_env_factory: bytes,
    worker_id: int,
    engine_configuration: EngineConfig,
) -> None:
    env_factory: Callable[
        [int, List[SideChannel]], UnityEnvironment
    ] = cloudpickle.loads(pickled_env_factory)
    shared_float_properties = FloatPropertiesChannel()
    engine_configuration_channel = EngineConfigurationChannel()
    engine_configuration_channel.set_configuration(engine_configuration)
    env: BaseUnityEnvironment = env_factory(
        worker_id, [shared_float_properties, engine_configuration_channel]
    )

    def _send_response(cmd_name, payload):
        parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))

    try:
        while True:
            cmd: EnvironmentCommand = parent_conn.recv()
            if cmd.name == "step":
                all_action_info = cmd.payload
                actions = {}
                values = {}
                for brain_name, action_info in all_action_info.items():
                    actions[brain_name] = action_info.action
                    values[brain_name] = action_info.value
                all_brain_info = env.step(vector_action=actions, value=values)
                # The timers in this process are independent of those in the other
                # worker processes and in the "main" process, so after we send back
                # the root timer we can safely clear them.
                # Note that we could randomly return timers a fraction of the time
                # if we wanted to reduce the data transferred.
                # TODO get gauges from the workers and merge them in the main process too.
                step_response = StepResponse(all_brain_info, get_timer_root())
                step_queue.put(EnvironmentResponse("step", worker_id, step_response))
                reset_timers()
            elif cmd.name == "external_brains":
                _send_response("external_brains", env.external_brains)
            elif cmd.name == "get_properties":
                reset_params = {}
                for k in shared_float_properties.list_properties():
                    reset_params[k] = shared_float_properties.get_property(k)

                _send_response("get_properties", reset_params)
            elif cmd.name == "reset":
                # SubprocessEnvManager.reset() may send None when no config is given.
                for k, v in (cmd.payload or {}).items():
                    shared_float_properties.set_property(k, v)
                all_brain_info = env.reset()
                _send_response("reset", all_brain_info)
            elif cmd.name == "close":
                break
    except (KeyboardInterrupt, UnityCommunicationException, UnityTimeOutException):
        logger.info(f"UnityEnvironment worker {worker_id}: environment stopping.")
        step_queue.put(EnvironmentResponse("env_close", worker_id, None))
    finally:
        # If this worker has put an item in the step queue that hasn't been processed
        # by the EnvManager, the process will hang until the item is processed.
        # We avoid this behavior by using Queue.cancel_join_thread().
        # See https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue.cancel_join_thread
        # for more info.
        logger.debug(f"UnityEnvironment worker {worker_id} closing.")
        step_queue.cancel_join_thread()
        step_queue.close()
        env.close()
        logger.debug(f"UnityEnvironment worker {worker_id} done.")


class SubprocessEnvManager(EnvManager):
    def __init__(
        self,
        env_factory: Callable[[int, List[SideChannel]], BaseUnityEnvironment],
        engine_configuration: EngineConfig,
        n_env: int = 1,
    ):
        super().__init__()
        self.env_workers: List[UnityEnvWorker] = []
        self.step_queue: Queue = Queue()
        for worker_idx in range(n_env):
            self.env_workers.append(
                self.create_worker(
                    worker_idx, self.step_queue, env_factory, engine_configuration
                )
            )

    @staticmethod
    def create_worker(
        worker_id: int,
        step_queue: Queue,
        env_factory: Callable[[int, List[SideChannel]], BaseUnityEnvironment],
        engine_configuration: EngineConfig,
    ) -> UnityEnvWorker:
        parent_conn, child_conn = Pipe()

        # Need to use cloudpickle for the env factory function since function objects aren't picklable
        # on Windows as of Python 3.6.
        pickled_env_factory = cloudpickle.dumps(env_factory)
        child_process = Process(
            target=worker,
            args=(
                child_conn,
                step_queue,
                pickled_env_factory,
                worker_id,
                engine_configuration,
            ),
        )
        child_process.start()
        return UnityEnvWorker(child_process, worker_id, parent_conn)

    def _queue_steps(self) -> None:
        for env_worker in self.env_workers:
            if not env_worker.waiting:
                env_action_info = self._take_step(env_worker.previous_step)
                env_worker.previous_all_action_info = env_action_info
                env_worker.send("step", env_action_info)
                env_worker.waiting = True

    def step(self) -> List[EnvironmentStep]:
        # Queue steps for any workers which aren't in the "waiting" state.
        self._queue_steps()

        worker_steps: List[EnvironmentResponse] = []
        step_workers: Set[int] = set()
        # Poll the step queue for completed steps from environment workers until we
        # retrieve one or more, which we then return as EnvironmentSteps.
        while len(worker_steps) < 1:
            try:
                while True:
                    step = self.step_queue.get_nowait()
                    if step.name == "env_close":
                        raise UnityCommunicationException(
                            "At least one of the environments has closed."
                        )
                    self.env_workers[step.worker_id].waiting = False
                    if step.worker_id not in step_workers:
                        worker_steps.append(step)
                        step_workers.add(step.worker_id)
            except EmptyQueueException:
                pass

        step_infos = self._postprocess_steps(worker_steps)
        return step_infos

    def reset(self, config: Optional[Dict] = None) -> List[EnvironmentStep]:
        while any(ew.waiting for ew in self.env_workers):
            if not self.step_queue.empty():
                step = self.step_queue.get_nowait()
                self.env_workers[step.worker_id].waiting = False
        # First enqueue reset commands for all workers so that they reset in parallel
        for ew in self.env_workers:
            ew.send("reset", config)
        # Next (synchronously) collect the reset observations from each worker in sequence
        for ew in self.env_workers:
            ew.previous_step = EnvironmentStep(None, ew.recv().payload, None)
        return [ew.previous_step for ew in self.env_workers]

    @property
    def external_brains(self) -> Dict[str, BrainParameters]:
        self.env_workers[0].send("external_brains")
        return self.env_workers[0].recv().payload

    @property
    def get_properties(self) -> Dict[str, float]:
        self.env_workers[0].send("get_properties")
        return self.env_workers[0].recv().payload

    def close(self) -> None:
        logger.debug("SubprocessEnvManager closing.")
        self.step_queue.close()
        self.step_queue.join_thread()
        for env_worker in self.env_workers:
            env_worker.close()

    def _postprocess_steps(
        self, env_steps: List[EnvironmentResponse]
    ) -> List[EnvironmentStep]:
        step_infos = []
        timer_nodes = []
        for step in env_steps:
            payload: StepResponse = step.payload
            env_worker = self.env_workers[step.worker_id]
            new_step = EnvironmentStep(
                env_worker.previous_step.current_all_brain_info,
                payload.all_brain_info,
                env_worker.previous_all_action_info,
            )
            step_infos.append(new_step)
            env_worker.previous_step = new_step

            if payload.timer_root:
                timer_nodes.append(payload.timer_root)

        if timer_nodes:
            with hierarchical_timer("workers") as main_timer_node:
                for worker_timer_node in timer_nodes:
                    main_timer_node.merge(
                        worker_timer_node, root_name="worker_root", is_parallel=True
                    )

        return step_infos

    @timed
    def _take_step(self, last_step: EnvironmentStep) -> Dict[str, ActionInfo]:
        all_action_info: Dict[str, ActionInfo] = {}
        for brain_name, brain_info in last_step.current_all_brain_info.items():
            if brain_name in self.policies:
                all_action_info[brain_name] = self.policies[brain_name].get_action(
                    brain_info
                )
        return all_action_info