Adds SubprocessUnityEnvironment for parallel envs (#1751)

This commit adds support for running Unity environments in parallel. An abstract base class was created for UnityEnvironment, and the new SubprocessUnityEnvironment inherits from it. SubprocessUnityEnvironment sends commands to its workers over pipes, so the commands run in parallel. A few significant changes were needed as a side effect:

* UnityEnvironments are created via a factory method (a closure) rather than directly by the main process.
* In mlagents-learn, "worker-id" has been replaced by "base-port" and "num-envs", and worker_ids are assigned automatically across runs.
* BrainInfo objects now convert all fields to numpy arrays or lists to avoid serialization issues.

Branch: develop-generalizationTraining-TrainerController

GitHub · 6 years ago
Current commit: 93760bc4
13 changed files, with 518 additions and 64 deletions
Changed files (lines changed):

- 2    ml-agents-envs/mlagents/envs/__init__.py
- 96   ml-agents-envs/mlagents/envs/brain.py
- 43   ml-agents-envs/mlagents/envs/environment.py
- 9    ml-agents-envs/mlagents/envs/tests/test_envs.py
- 65   ml-agents/mlagents/trainers/learn.py
- 6    ml-agents/mlagents/trainers/ppo/trainer.py
- 19   ml-agents/mlagents/trainers/tests/test_learn.py
- 27   ml-agents/mlagents/trainers/trainer_controller.py
- 33   ml-agents-envs/mlagents/envs/base_unity_environment.py
- 192  ml-agents-envs/mlagents/envs/subprocess_environment.py
- 90   ml-agents/tests/envs/test_subprocess_unity_environment.py
- 0    ml-agents-envs/__init__.py
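Based on the API added in this commit (see subprocess_environment.py below), a caller wraps environment creation in a factory closure and hands it to SubprocessUnityEnvironment. A minimal sketch, assuming UnityEnvironment accepts file_name, worker_id, and base_port arguments as in the existing API (the executable name and port are placeholders, not part of the commit):

from mlagents.envs import UnityEnvironment
from mlagents.envs.subprocess_environment import SubprocessUnityEnvironment


def make_env_factory(file_name: str, base_port: int):
    # Each worker gets its own UnityEnvironment; worker_id keeps the ports distinct.
    def env_factory(worker_id: int) -> UnityEnvironment:
        return UnityEnvironment(file_name=file_name, worker_id=worker_id, base_port=base_port)
    return env_factory


env = SubprocessUnityEnvironment(make_env_factory('3DBall', 5005), n_env=2)
all_brain_info = env.reset(train_mode=True)
# ... training loop calls env.step(vector_action={...}) ...
env.close()

From the command line, the same behaviour is exposed through the new options mentioned in the commit description, e.g. mlagents-learn <trainer-config> --num-envs=2 --base-port=5005 (flag spelling assumed from the description, not shown in this diff).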
ml-agents-envs/mlagents/envs/__init__.py

from .environment import *
from .exception import *
ml-agents-envs/mlagents/envs/base_unity_environment.py

from abc import ABC, abstractmethod
from typing import Dict

from mlagents.envs import AllBrainInfo, BrainParameters


class BaseUnityEnvironment(ABC):
    @abstractmethod
    def step(self, vector_action=None, memory=None, text_action=None, value=None) -> AllBrainInfo:
        pass

    @abstractmethod
    def reset(self, config=None, train_mode=True) -> AllBrainInfo:
        pass

    @property
    @abstractmethod
    def global_done(self):
        pass

    @property
    @abstractmethod
    def external_brains(self) -> Dict[str, BrainParameters]:
        pass

    @property
    @abstractmethod
    def reset_parameters(self) -> Dict[str, str]:
        pass

    @abstractmethod
    def close(self):
        pass
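Per the commit description, UnityEnvironment and the new SubprocessUnityEnvironment both implement this interface, so the trainer can drive a single local environment or a pool of subprocess environments through the same calls. A minimal sketch of what an implementation must provide (illustrative only, not part of the commit; the empty return values are placeholders):

from typing import Dict

from mlagents.envs import AllBrainInfo, BrainParameters
from mlagents.envs.base_unity_environment import BaseUnityEnvironment


class NoOpEnvironment(BaseUnityEnvironment):
    """Does nothing; exists only to show the required overrides."""

    def step(self, vector_action=None, memory=None, text_action=None, value=None) -> AllBrainInfo:
        return {}  # a real implementation returns BrainInfo objects keyed by brain name

    def reset(self, config=None, train_mode=True) -> AllBrainInfo:
        return {}

    @property
    def global_done(self):
        return False

    @property
    def external_brains(self) -> Dict[str, BrainParameters]:
        return {}

    @property
    def reset_parameters(self) -> Dict[str, str]:
        return {}

    def close(self):
        pass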
ml-agents-envs/mlagents/envs/subprocess_environment.py

from typing import *
import copy
import numpy as np

from mlagents.envs import UnityEnvironment
from multiprocessing import Process, Pipe
from multiprocessing.connection import Connection
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs import AllBrainInfo, UnityEnvironmentException


class EnvironmentCommand(NamedTuple):
    name: str
    payload: Any = None


class EnvironmentResponse(NamedTuple):
    name: str
    worker_id: int
    payload: Any


class UnityEnvWorker(NamedTuple):
    process: Process
    worker_id: int
    conn: Connection

    def send(self, name: str, payload=None):
        cmd = EnvironmentCommand(name, payload)
        self.conn.send(cmd)

    def recv(self) -> EnvironmentResponse:
        response: EnvironmentResponse = self.conn.recv()
        return response

    def close(self):
        # Ask the worker loop to exit, then wait for the child process to finish.
        self.conn.send(EnvironmentCommand('close'))
        self.process.join()
def worker(parent_conn: Connection, env_factory: Callable[[int], UnityEnvironment], worker_id: int):
    env = env_factory(worker_id)

    def _send_response(cmd_name, payload):
        parent_conn.send(
            EnvironmentResponse(cmd_name, worker_id, payload)
        )
    try:
        while True:
            cmd: EnvironmentCommand = parent_conn.recv()
            if cmd.name == 'step':
                vector_action, memory, text_action, value = cmd.payload
                all_brain_info = env.step(vector_action, memory, text_action, value)
                _send_response('step', all_brain_info)
            elif cmd.name == 'external_brains':
                _send_response('external_brains', env.external_brains)
            elif cmd.name == 'reset_parameters':
                _send_response('reset_parameters', env.reset_parameters)
            elif cmd.name == 'reset':
                all_brain_info = env.reset(cmd.payload[0], cmd.payload[1])
                _send_response('reset', all_brain_info)
            elif cmd.name == 'global_done':
                _send_response('global_done', env.global_done)
            elif cmd.name == 'close':
                break
    except KeyboardInterrupt:
        print('UnityEnvironment worker: keyboard interrupt')
    finally:
        # Close the environment exactly once, whether the loop exited via a
        # 'close' command or a keyboard interrupt.
        env.close()

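The loop above is a simple request/response protocol: the parent sends EnvironmentCommand tuples down the pipe, and the worker answers each one with an EnvironmentResponse. A standalone sketch of one round trip using a fake factory instead of a real Unity binary (MockEnv and its values are illustrative assumptions, not part of the commit):

from multiprocessing import Pipe, Process

from mlagents.envs.subprocess_environment import EnvironmentCommand, worker


class MockEnv:
    # Only the members the worker loop touches are stubbed out.
    reset_parameters = {'gravity': '9.8'}
    external_brains = {}
    global_done = False

    def step(self, *args):
        return {}

    def reset(self, *args):
        return {}

    def close(self):
        pass


def mock_factory(worker_id: int) -> MockEnv:
    return MockEnv()


if __name__ == '__main__':
    parent_conn, child_conn = Pipe()
    child = Process(target=worker, args=(child_conn, mock_factory, 0))
    child.start()

    parent_conn.send(EnvironmentCommand('reset_parameters'))
    print(parent_conn.recv())
    # EnvironmentResponse(name='reset_parameters', worker_id=0, payload={'gravity': '9.8'})

    parent_conn.send(EnvironmentCommand('close'))
    child.join()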
class SubprocessUnityEnvironment(BaseUnityEnvironment):
    def __init__(self,
                 env_factory: Callable[[int], BaseUnityEnvironment],
                 n_env: int = 1):
        self.envs = []
        self.env_agent_counts = {}
        self.waiting = False
        for worker_id in range(n_env):
            self.envs.append(self.create_worker(worker_id, env_factory))

    @staticmethod
    def create_worker(
            worker_id: int,
            env_factory: Callable[[int], BaseUnityEnvironment]
    ) -> UnityEnvWorker:
        parent_conn, child_conn = Pipe()
        child_process = Process(target=worker, args=(child_conn, env_factory, worker_id))
        child_process.start()
        return UnityEnvWorker(child_process, worker_id, parent_conn)
    def step_async(self, vector_action, memory=None, text_action=None, value=None) -> None:
        if self.waiting:
            raise UnityEnvironmentException(
                'Tried to take an environment step before the previous step had completed.'
            )

        agent_counts_cum = {}
        for brain_name in self.env_agent_counts.keys():
            agent_counts_cum[brain_name] = np.cumsum(self.env_agent_counts[brain_name])

        # Split the actions provided by the previous set of agent counts, and send the step
        # commands to the workers.
        for worker_id, env in enumerate(self.envs):
            env_actions = {}
            env_memory = {}
            env_text_action = {}
            env_value = {}
            for brain_name in self.env_agent_counts.keys():
                start_ind = 0
                if worker_id > 0:
                    start_ind = agent_counts_cum[brain_name][worker_id - 1]
                end_ind = agent_counts_cum[brain_name][worker_id]
                if vector_action.get(brain_name) is not None:
                    env_actions[brain_name] = vector_action[brain_name][start_ind:end_ind]
                if memory and memory.get(brain_name) is not None:
                    env_memory[brain_name] = memory[brain_name][start_ind:end_ind]
                if text_action and text_action.get(brain_name) is not None:
                    env_text_action[brain_name] = text_action[brain_name][start_ind:end_ind]
                if value and value.get(brain_name) is not None:
                    env_value[brain_name] = value[brain_name][start_ind:end_ind]

            env.send('step', (env_actions, env_memory, env_text_action, env_value))
        self.waiting = True

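The slicing above uses the per-brain agent counts recorded from the previous step or reset. For instance, with one brain and agent counts [1, 3, 5] across three workers, np.cumsum gives [1, 4, 9], so worker 0 receives actions[0:1], worker 1 receives actions[1:4], and worker 2 receives actions[4:9]. A standalone illustration (not part of the commit):

import numpy as np

agent_counts = [1, 3, 5]                                   # agents per worker for a single brain
cum = np.cumsum(agent_counts)                              # array([1, 4, 9])
actions = [[float(i)] for i in range(sum(agent_counts))]   # 9 flattened actions across all workers

for worker_id in range(len(agent_counts)):
    start = 0 if worker_id == 0 else cum[worker_id - 1]
    end = cum[worker_id]
    print(worker_id, actions[start:end])
# 0 [[0.0]]
# 1 [[1.0], [2.0], [3.0]]
# 2 [[4.0], [5.0], [6.0], [7.0], [8.0]]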
    def step_await(self) -> AllBrainInfo:
        if not self.waiting:
            raise UnityEnvironmentException('Tried to await an environment step, but no async step was taken.')

        steps = [self.envs[i].recv() for i in range(len(self.envs))]
        self._get_agent_counts(map(lambda s: s.payload, steps))
        combined_brain_info = self._merge_step_info(steps)
        self.waiting = False
        return combined_brain_info

    def step(self, vector_action=None, memory=None, text_action=None, value=None) -> AllBrainInfo:
        self.step_async(vector_action, memory, text_action, value)
        return self.step_await()

    def reset(self, config=None, train_mode=True) -> AllBrainInfo:
        self._broadcast_message('reset', (config, train_mode))
        reset_results = [self.envs[i].recv() for i in range(len(self.envs))]
        self._get_agent_counts(map(lambda r: r.payload, reset_results))

        return self._merge_step_info(reset_results)

    @property
    def global_done(self):
        self._broadcast_message('global_done')
        dones: List[EnvironmentResponse] = [
            self.envs[i].recv().payload for i in range(len(self.envs))
        ]
        return all(dones)

    @property
    def external_brains(self):
        self.envs[0].send('external_brains')
        return self.envs[0].recv().payload

    @property
    def reset_parameters(self):
        self.envs[0].send('reset_parameters')
        return self.envs[0].recv().payload

    def close(self):
        for env in self.envs:
            env.close()

    def _get_agent_counts(self, step_list: Iterable[AllBrainInfo]):
        for i, step in enumerate(step_list):
            for brain_name, brain_info in step.items():
                if brain_name not in self.env_agent_counts.keys():
                    self.env_agent_counts[brain_name] = [0] * len(self.envs)
                self.env_agent_counts[brain_name][i] = len(brain_info.agents)

    @staticmethod
    def _merge_step_info(env_steps: List[EnvironmentResponse]) -> AllBrainInfo:
        accumulated_brain_info: AllBrainInfo = None
        for env_step in env_steps:
            all_brain_info: AllBrainInfo = env_step.payload
            for brain_name, brain_info in all_brain_info.items():
                # Prefix each agent id with the worker id so that agents from
                # different subprocesses cannot collide in the combined BrainInfo.
                for i in range(len(brain_info.agents)):
                    brain_info.agents[i] = str(env_step.worker_id) + '-' + str(brain_info.agents[i])
                if accumulated_brain_info:
                    accumulated_brain_info[brain_name].merge(brain_info)
            if not accumulated_brain_info:
                accumulated_brain_info = copy.deepcopy(all_brain_info)
        return accumulated_brain_info

    def _broadcast_message(self, name: str, payload=None):
        for env in self.envs:
            env.send(name, payload)
ml-agents/tests/envs/test_subprocess_unity_environment.py

import unittest.mock as mock
from unittest.mock import MagicMock
import unittest

from mlagents.envs.subprocess_environment import *
from mlagents.envs import UnityEnvironmentException, BrainInfo


def mock_env_factory(worker_id: int):
    return mock.create_autospec(spec=BaseUnityEnvironment)


class MockEnvWorker:
    def __init__(self, worker_id):
        self.worker_id = worker_id
        self.process = None
        self.conn = None
        self.send = MagicMock()
        self.recv = MagicMock()


class SubprocessEnvironmentTest(unittest.TestCase):
    def test_environments_are_created(self):
        SubprocessUnityEnvironment.create_worker = MagicMock()
        env = SubprocessUnityEnvironment(mock_env_factory, 2)
        # Creates two processes
        self.assertEqual(env.create_worker.call_args_list, [
            mock.call(0, mock_env_factory),
            mock.call(1, mock_env_factory)
        ])
        self.assertEqual(len(env.envs), 2)

    def test_step_async_fails_when_waiting(self):
        env = SubprocessUnityEnvironment(mock_env_factory, 0)
        env.waiting = True
        with self.assertRaises(UnityEnvironmentException):
            env.step_async(vector_action=[])

    @staticmethod
    def test_step_async_splits_input_by_agent_count():
        env = SubprocessUnityEnvironment(mock_env_factory, 0)
        env.env_agent_counts = {
            'MockBrain': [1, 3, 5]
        }
        env.envs = [
            MockEnvWorker(0),
            MockEnvWorker(1),
            MockEnvWorker(2),
        ]
        env_0_actions = [[1.0, 2.0]]
        env_1_actions = ([[3.0, 4.0]] * 3)
        env_2_actions = ([[5.0, 6.0]] * 5)
        vector_action = {
            'MockBrain': env_0_actions + env_1_actions + env_2_actions
        }
        env.step_async(vector_action=vector_action)
        env.envs[0].send.assert_called_with('step', ({'MockBrain': env_0_actions}, {}, {}, {}))
        env.envs[1].send.assert_called_with('step', ({'MockBrain': env_1_actions}, {}, {}, {}))
        env.envs[2].send.assert_called_with('step', ({'MockBrain': env_2_actions}, {}, {}, {}))

    def test_step_async_sets_waiting(self):
        env = SubprocessUnityEnvironment(mock_env_factory, 0)
        env.step_async(vector_action=[])
        self.assertTrue(env.waiting)

    def test_step_await_fails_if_not_waiting(self):
        env = SubprocessUnityEnvironment(mock_env_factory, 0)
        with self.assertRaises(UnityEnvironmentException):
            env.step_await()

    def test_step_await_combines_brain_info(self):
        all_brain_info_env0 = {
            'MockBrain': BrainInfo([], [[1.0, 2.0], [1.0, 2.0]], [], agents=[1, 2], memory=np.zeros((0, 0)))
        }
        all_brain_info_env1 = {
            'MockBrain': BrainInfo([], [[3.0, 4.0]], [], agents=[3], memory=np.zeros((0, 0)))
        }
        env_worker_0 = MockEnvWorker(0)
        env_worker_0.recv.return_value = EnvironmentResponse('step', 0, all_brain_info_env0)
        env_worker_1 = MockEnvWorker(1)
        env_worker_1.recv.return_value = EnvironmentResponse('step', 1, all_brain_info_env1)
        env = SubprocessUnityEnvironment(mock_env_factory, 0)
        env.envs = [env_worker_0, env_worker_1]
        env.waiting = True
        combined_braininfo = env.step_await()['MockBrain']
        self.assertEqual(
            combined_braininfo.vector_observations.tolist(),
            [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0]]
        )
        self.assertEqual(combined_braininfo.agents, ['0-1', '0-2', '1-3'])