主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
207 行
8.0 KiB
207 行
8.0 KiB
from typing import *
import copy
import numpy as np
import cloudpickle
from mlagents.envs import UnityEnvironment
from multiprocessing import Process, Pipe
from multiprocessing.connection import Connection
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs import AllBrainInfo, UnityEnvironmentException
class EnvironmentCommand(NamedTuple):
    """A request sent from the parent process to an environment worker."""

    # Command identifier, e.g. 'step', 'reset' or 'close'.
    name: str
    # Arguments for the command; commands without arguments leave this as None.
    payload: Any = None
class EnvironmentResponse(NamedTuple):
    """A reply sent from an environment worker back to the parent process."""

    # Name of the command this response answers.
    name: str
    # Id of the worker that produced the response.
    worker_id: int
    # Result of the command.
    payload: Any
class UnityEnvWorker(NamedTuple):
    """Parent-side handle to a single environment subprocess.

    Bundles the child process with the parent's end of the pipe used to
    exchange EnvironmentCommand / EnvironmentResponse messages with it.
    """

    # The subprocess running the worker loop.
    process: Process
    # Id assigned to this worker.
    worker_id: int
    # Parent end of the pipe connected to the worker.
    conn: Connection

    def send(self, name: str, payload=None):
        """Send the command ``name`` with an optional ``payload`` to the worker.

        Raises:
            KeyboardInterrupt: if the pipe to the worker is broken or closed.
        """
        try:
            cmd = EnvironmentCommand(name, payload)
            self.conn.send(cmd)
        except (BrokenPipeError, EOFError):
            raise KeyboardInterrupt

    def recv(self) -> EnvironmentResponse:
        """Block until the worker replies and return its response.

        Raises:
            KeyboardInterrupt: if the pipe to the worker is broken or closed.
        """
        try:
            response: EnvironmentResponse = self.conn.recv()
            return response
        except (BrokenPipeError, EOFError):
            raise KeyboardInterrupt

    def close(self):
        """Ask the worker to shut down and wait for its process to exit."""
        try:
            self.conn.send(EnvironmentCommand('close'))
        except (BrokenPipeError, EOFError):
            # The pipe is already gone, so there is nobody left to notify;
            # closing is best-effort and must not raise here.
            pass
        # Wait for the child so it is not left behind as a zombie process.
        self.process.join()
def worker(parent_conn: Connection, pickled_env_factory: str, worker_id: int):
    """Entry point of an environment subprocess: serve commands until told to close.

    Args:
        parent_conn: This process's end of the pipe to the parent; commands are
            read from it and responses are written back to it.
        pickled_env_factory: The cloudpickle-serialized environment factory
            (as produced by ``cloudpickle.dumps``).
        worker_id: Id passed to the factory and stamped on every response.
    """
    # The pickled factory is produced by the parent process itself, so it is
    # trusted input; never feed externally supplied data to cloudpickle.loads.
    env_factory: Callable[[int], UnityEnvironment] = cloudpickle.loads(pickled_env_factory)
    env = env_factory(worker_id)

    def _send_response(cmd_name, payload):
        # Tag the reply with this worker's id so the parent can attribute it.
        parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))

    try:
        while True:
            cmd: EnvironmentCommand = parent_conn.recv()
            if cmd.name == 'step':
                vector_action, memory, text_action, value = cmd.payload
                all_brain_info = env.step(vector_action, memory, text_action, value)
                _send_response('step', all_brain_info)
            elif cmd.name == 'external_brains':
                _send_response('external_brains', env.external_brains)
            elif cmd.name == 'reset_parameters':
                _send_response('reset_parameters', env.reset_parameters)
            elif cmd.name == 'reset':
                all_brain_info = env.reset(cmd.payload[0], cmd.payload[1])
                _send_response('reset', all_brain_info)
            elif cmd.name == 'global_done':
                _send_response('global_done', env.global_done)
            elif cmd.name == 'close':
                # Leave the loop; the environment is released in ``finally``.
                break
    except KeyboardInterrupt:
        print('UnityEnvironment worker: keyboard interrupt')
    finally:
        # Always release the environment, whether we were closed or interrupted.
        env.close()
class SubprocessUnityEnvironment(BaseUnityEnvironment):
    """Runs several environments in subprocesses behind a single environment interface.

    Each environment lives in its own worker process. Actions passed to
    ``step`` are split per worker using the agent counts observed in the most
    recent step or reset, and the per-worker results are merged back into one
    AllBrainInfo whose agent ids are prefixed with the worker id.
    """

    def __init__(self,
                 env_factory: Callable[[int], BaseUnityEnvironment],
                 n_env: int = 1):
        """Start ``n_env`` workers, each building its environment via ``env_factory``.

        Args:
            env_factory: Callable that receives a worker id and returns the
                environment to run in that worker.
            n_env: Number of worker processes to start.
        """
        self.envs = []
        # Maps brain name -> list with one agent count per worker, refreshed
        # after every step/reset and used to slice the next batch of actions.
        self.env_agent_counts = {}
        # True between step_async() and the matching step_await().
        self.waiting = False
        for worker_id in range(n_env):
            self.envs.append(self.create_worker(worker_id, env_factory))

    @staticmethod
    def create_worker(
            worker_id: int,
            env_factory: Callable[[int], BaseUnityEnvironment]
    ) -> UnityEnvWorker:
        """Spawn and start a worker process and return the parent-side handle to it."""
        parent_conn, child_conn = Pipe()
        # Need to use cloudpickle for the env factory function since function objects aren't picklable
        # on Windows as of Python 3.6.
        pickled_env_factory = cloudpickle.dumps(env_factory)
        child_process = Process(target=worker, args=(child_conn, pickled_env_factory, worker_id))
        # Without start() the worker never runs and every recv() would block forever.
        child_process.start()
        return UnityEnvWorker(child_process, worker_id, parent_conn)

    def step_async(self, vector_action, memory=None, text_action=None, value=None) -> None:
        """Split the given per-brain inputs by worker and send a step command to each.

        Raises:
            UnityEnvironmentException: if a previous async step has not been awaited yet.
        """
        if self.waiting:
            raise UnityEnvironmentException(
                'Tried to take an environment step before previous step has completed.'
            )
        # step() defaults vector_action to None; treat that as "no actions"
        # instead of failing on None.get below.
        if vector_action is None:
            vector_action = {}

        # Cumulative agent counts give, per brain, the slice boundaries of each
        # worker inside the concatenated action arrays.
        agent_counts_cum = {}
        for brain_name in self.env_agent_counts.keys():
            agent_counts_cum[brain_name] = np.cumsum(self.env_agent_counts[brain_name])

        # Split the actions provided by the previous set of agent counts, and send the step
        # commands to the workers.
        for worker_id, env in enumerate(self.envs):
            env_actions = {}
            env_memory = {}
            env_text_action = {}
            env_value = {}
            for brain_name in self.env_agent_counts.keys():
                start_ind = 0
                if worker_id > 0:
                    start_ind = agent_counts_cum[brain_name][worker_id - 1]
                end_ind = agent_counts_cum[brain_name][worker_id]
                if vector_action.get(brain_name) is not None:
                    env_actions[brain_name] = vector_action[brain_name][start_ind:end_ind]
                if memory and memory.get(brain_name) is not None:
                    env_memory[brain_name] = memory[brain_name][start_ind:end_ind]
                if text_action and text_action.get(brain_name) is not None:
                    env_text_action[brain_name] = text_action[brain_name][start_ind:end_ind]
                if value and value.get(brain_name) is not None:
                    env_value[brain_name] = value[brain_name][start_ind:end_ind]
            env.send('step', (env_actions, env_memory, env_text_action, env_value))
        self.waiting = True

    def step_await(self) -> AllBrainInfo:
        """Collect the results of the pending async step from every worker and merge them.

        Raises:
            UnityEnvironmentException: if no async step is pending.
        """
        if not self.waiting:
            raise UnityEnvironmentException('Tried to await an environment step, but no async step was taken.')
        steps = [env.recv() for env in self.envs]
        self._get_agent_counts(step.payload for step in steps)
        combined_brain_info = self._merge_step_info(steps)
        self.waiting = False
        return combined_brain_info

    def step(self, vector_action=None, memory=None, text_action=None, value=None) -> AllBrainInfo:
        """Step all environments synchronously and return the merged brain info."""
        self.step_async(vector_action, memory, text_action, value)
        return self.step_await()

    def reset(self, config=None, train_mode=True) -> AllBrainInfo:
        """Reset every environment and return the merged brain info."""
        self._broadcast_message('reset', (config, train_mode))
        reset_results = [env.recv() for env in self.envs]
        self._get_agent_counts(result.payload for result in reset_results)
        return self._merge_step_info(reset_results)

    @property
    def global_done(self):
        """True when every worker reports its environment as globally done."""
        self._broadcast_message('global_done')
        # Build the full list before calling all(): every worker's reply must be
        # read off its pipe, so the receive loop must not short-circuit.
        dones = [env.recv().payload for env in self.envs]
        return all(dones)

    @property
    def external_brains(self):
        """External brains as reported by the first worker's environment."""
        self.envs[0].send('external_brains')
        return self.envs[0].recv().payload

    @property
    def reset_parameters(self):
        """Reset parameters as reported by the first worker's environment."""
        self.envs[0].send('reset_parameters')
        return self.envs[0].recv().payload

    def close(self):
        """Shut down every worker."""
        for env in self.envs:
            env.close()

    def _get_agent_counts(self, step_list: Iterable[AllBrainInfo]):
        """Record, per brain, how many agents each worker reported in ``step_list``."""
        for i, step in enumerate(step_list):
            for brain_name, brain_info in step.items():
                if brain_name not in self.env_agent_counts.keys():
                    self.env_agent_counts[brain_name] = [0] * len(self.envs)
                self.env_agent_counts[brain_name][i] = len(brain_info.agents)

    @staticmethod
    def _merge_step_info(env_steps: List[EnvironmentResponse]) -> AllBrainInfo:
        """Combine the per-worker brain infos into one, in worker order.

        Agent ids are rewritten in place to '<worker_id>-<agent_id>' so they stay
        unique across workers. Returns None when ``env_steps`` is empty.
        """
        accumulated_brain_info: AllBrainInfo = None
        for env_step in env_steps:
            all_brain_info: AllBrainInfo = env_step.payload
            for brain_name, brain_info in all_brain_info.items():
                for i, agent_id in enumerate(brain_info.agents):
                    brain_info.agents[i] = str(env_step.worker_id) + '-' + str(agent_id)
                if accumulated_brain_info:
                    # NOTE(review): relies on BrainInfo.merge appending the other
                    # info's agents in place — confirm against the BrainInfo class.
                    accumulated_brain_info[brain_name].merge(brain_info)
            if not accumulated_brain_info:
                # First non-empty result seeds the accumulator; copy it so later
                # merges do not mutate the worker's payload.
                accumulated_brain_info = copy.deepcopy(all_brain_info)
        return accumulated_brain_info

    def _broadcast_message(self, name: str, payload = None):
        """Send the same command to every worker."""
        for env in self.envs:
            env.send(name, payload)