您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
187 行
6.8 KiB
187 行
6.8 KiB
from typing import *
|
|
import cloudpickle
|
|
|
|
from mlagents.envs import UnityEnvironment
|
|
from multiprocessing import Process, Pipe
|
|
from multiprocessing.connection import Connection
|
|
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
|
|
from mlagents.envs.env_manager import EnvManager, StepInfo
|
|
from mlagents.envs.timers import timed, hierarchical_timer
|
|
from mlagents.envs import AllBrainInfo, BrainParameters, ActionInfo
|
|
|
|
|
|
class EnvironmentCommand(NamedTuple):
|
|
name: str
|
|
payload: Any = None
|
|
|
|
|
|
class EnvironmentResponse(NamedTuple):
|
|
name: str
|
|
worker_id: int
|
|
payload: Any
|
|
|
|
|
|
class UnityEnvWorker:
|
|
def __init__(self, process: Process, worker_id: int, conn: Connection):
|
|
self.process = process
|
|
self.worker_id = worker_id
|
|
self.conn = conn
|
|
self.previous_step: StepInfo = StepInfo(None, {}, None)
|
|
self.previous_all_action_info: Dict[str, ActionInfo] = {}
|
|
|
|
def send(self, name: str, payload=None):
|
|
try:
|
|
cmd = EnvironmentCommand(name, payload)
|
|
self.conn.send(cmd)
|
|
except (BrokenPipeError, EOFError):
|
|
raise KeyboardInterrupt
|
|
|
|
def recv(self) -> EnvironmentResponse:
|
|
try:
|
|
response: EnvironmentResponse = self.conn.recv()
|
|
return response
|
|
except (BrokenPipeError, EOFError):
|
|
raise KeyboardInterrupt
|
|
|
|
def close(self):
|
|
try:
|
|
self.conn.send(EnvironmentCommand("close"))
|
|
except (BrokenPipeError, EOFError):
|
|
pass
|
|
self.process.join()
|
|
|
|
|
|
def worker(parent_conn: Connection, pickled_env_factory: str, worker_id: int):
|
|
env_factory: Callable[[int], UnityEnvironment] = cloudpickle.loads(
|
|
pickled_env_factory
|
|
)
|
|
env = env_factory(worker_id)
|
|
|
|
def _send_response(cmd_name, payload):
|
|
parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))
|
|
|
|
try:
|
|
while True:
|
|
cmd: EnvironmentCommand = parent_conn.recv()
|
|
if cmd.name == "step":
|
|
all_action_info = cmd.payload
|
|
if env.global_done:
|
|
all_brain_info = env.reset()
|
|
else:
|
|
actions = {}
|
|
memories = {}
|
|
texts = {}
|
|
values = {}
|
|
for brain_name, action_info in all_action_info.items():
|
|
actions[brain_name] = action_info.action
|
|
memories[brain_name] = action_info.memory
|
|
texts[brain_name] = action_info.text
|
|
values[brain_name] = action_info.value
|
|
all_brain_info = env.step(actions, memories, texts, values)
|
|
_send_response("step", all_brain_info)
|
|
elif cmd.name == "external_brains":
|
|
_send_response("external_brains", env.external_brains)
|
|
elif cmd.name == "reset_parameters":
|
|
_send_response("reset_parameters", env.reset_parameters)
|
|
elif cmd.name == "reset":
|
|
all_brain_info = env.reset(
|
|
cmd.payload[0], cmd.payload[1], cmd.payload[2]
|
|
)
|
|
_send_response("reset", all_brain_info)
|
|
elif cmd.name == "global_done":
|
|
_send_response("global_done", env.global_done)
|
|
elif cmd.name == "close":
|
|
break
|
|
except KeyboardInterrupt:
|
|
print("UnityEnvironment worker: keyboard interrupt")
|
|
finally:
|
|
env.close()
|
|
|
|
|
|
class SubprocessEnvManager(EnvManager):
|
|
def __init__(
|
|
self, env_factory: Callable[[int], BaseUnityEnvironment], n_env: int = 1
|
|
):
|
|
super().__init__()
|
|
self.env_workers: List[UnityEnvWorker] = []
|
|
for worker_idx in range(n_env):
|
|
self.env_workers.append(self.create_worker(worker_idx, env_factory))
|
|
|
|
def get_last_steps(self):
|
|
return [ew.previous_step for ew in self.env_workers]
|
|
|
|
@staticmethod
|
|
def create_worker(
|
|
worker_id: int, env_factory: Callable[[int], BaseUnityEnvironment]
|
|
) -> UnityEnvWorker:
|
|
parent_conn, child_conn = Pipe()
|
|
|
|
# Need to use cloudpickle for the env factory function since function objects aren't picklable
|
|
# on Windows as of Python 3.6.
|
|
pickled_env_factory = cloudpickle.dumps(env_factory)
|
|
child_process = Process(
|
|
target=worker, args=(child_conn, pickled_env_factory, worker_id)
|
|
)
|
|
child_process.start()
|
|
return UnityEnvWorker(child_process, worker_id, parent_conn)
|
|
|
|
def step(self) -> List[StepInfo]:
|
|
for env_worker in self.env_workers:
|
|
all_action_info = self._take_step(env_worker.previous_step)
|
|
env_worker.previous_all_action_info = all_action_info
|
|
env_worker.send("step", all_action_info)
|
|
|
|
with hierarchical_timer("recv"):
|
|
step_brain_infos: List[AllBrainInfo] = [
|
|
self.env_workers[i].recv().payload for i in range(len(self.env_workers))
|
|
]
|
|
steps = []
|
|
for i in range(len(step_brain_infos)):
|
|
env_worker = self.env_workers[i]
|
|
step_info = StepInfo(
|
|
env_worker.previous_step.current_all_brain_info,
|
|
step_brain_infos[i],
|
|
env_worker.previous_all_action_info,
|
|
)
|
|
env_worker.previous_step = step_info
|
|
steps.append(step_info)
|
|
return steps
|
|
|
|
def reset(
|
|
self, config=None, train_mode=True, custom_reset_parameters=None
|
|
) -> List[StepInfo]:
|
|
self._broadcast_message("reset", (config, train_mode, custom_reset_parameters))
|
|
reset_results = [
|
|
self.env_workers[i].recv().payload for i in range(len(self.env_workers))
|
|
]
|
|
for i in range(len(reset_results)):
|
|
env_worker = self.env_workers[i]
|
|
env_worker.previous_step = StepInfo(None, reset_results[i], None)
|
|
return list(map(lambda ew: ew.previous_step, self.env_workers))
|
|
|
|
@property
|
|
def external_brains(self) -> Dict[str, BrainParameters]:
|
|
self.env_workers[0].send("external_brains")
|
|
return self.env_workers[0].recv().payload
|
|
|
|
@property
|
|
def reset_parameters(self) -> Dict[str, float]:
|
|
self.env_workers[0].send("reset_parameters")
|
|
return self.env_workers[0].recv().payload
|
|
|
|
def close(self):
|
|
for env in self.env_workers:
|
|
env.close()
|
|
|
|
def _broadcast_message(self, name: str, payload=None):
|
|
for env in self.env_workers:
|
|
env.send(name, payload)
|
|
|
|
@timed
|
|
def _take_step(self, last_step: StepInfo) -> Dict[str, ActionInfo]:
|
|
all_action_info: Dict[str, ActionInfo] = {}
|
|
for brain_name, brain_info in last_step.current_all_brain_info.items():
|
|
all_action_info[brain_name] = self.policies[brain_name].get_action(
|
|
brain_info
|
|
)
|
|
return all_action_info
|