
Make SubprocessEnvManager take asynchronous steps (#2265)

SubprocessEnvManager currently takes steps synchronously to reproduce the
old behavior, which means all parallel environments must wait for the
slowest environment to finish its step. If some steps take much longer
than others, this can cause a substantial overall slowdown in practice.
In extreme cases we've seen close to a 2x speedup from asynchronous
stepping, with no downside for the faster environments (16% improvement
on Bouncer and 14% on Walker in our tests).

This PR changes the SubprocessEnvManager to use asynchronous stepping.
On each "step" call, the environment manager enqueues step requests to
the workers and then waits only until at least one step has completed
before returning.
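
As a rough illustration of the pattern (a minimal sketch, not the actual
implementation; the names step_queue, waiting, and worker_id mirror the
diff below, while everything else is illustrative):

from multiprocessing import Queue
from queue import Empty as EmptyQueueException

def step_async(env_workers, step_queue: Queue):
    # Enqueue a step request for every worker that is not already busy.
    for worker in env_workers:
        if not worker.waiting:
            worker.send("step")
            worker.waiting = True
    # Drain the queue until at least one completed step has arrived;
    # workers that finish later are picked up on a subsequent call.
    completed = []
    while not completed:
        try:
            while True:
                response = step_queue.get_nowait()
                env_workers[response.worker_id].waiting = False
                completed.append(response)
        except EmptyQueueException:
            pass
    return completed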
GitHub · 5 years ago
Commit a802d0d7
2 files changed, 117 insertions(+), 61 deletions(-)
  1. ml-agents-envs/mlagents/envs/subprocess_env_manager.py (115 changed lines)
  2. ml-agents-envs/mlagents/envs/tests/test_subprocess_env_manager.py (63 changed lines)

ml-agents-envs/mlagents/envs/subprocess_env_manager.py (115 changed lines)


import cloudpickle
from mlagents.envs import UnityEnvironment
from multiprocessing import Process, Pipe
from multiprocessing import Process, Pipe, Queue
from queue import Empty as EmptyQueueException
from mlagents.envs.timers import timed, hierarchical_timer
from mlagents.envs import AllBrainInfo, BrainParameters, ActionInfo
from mlagents.envs.timers import timed
from mlagents.envs import BrainParameters, ActionInfo
class EnvironmentCommand(NamedTuple):

self.conn = conn
self.previous_step: StepInfo = StepInfo(None, {}, None)
self.previous_all_action_info: Dict[str, ActionInfo] = {}
self.waiting = False
def send(self, name: str, payload=None):
try:

self.process.join()
def worker(parent_conn: Connection, pickled_env_factory: str, worker_id: int):
def worker(
parent_conn: Connection, step_queue: Queue, pickled_env_factory: str, worker_id: int
):
env_factory: Callable[[int], UnityEnvironment] = cloudpickle.loads(
pickled_env_factory
)

cmd: EnvironmentCommand = parent_conn.recv()
if cmd.name == "step":
all_action_info = cmd.payload
# When an environment is "global_done" it means automatic agent reset won't occur, so we need
# to perform an academy reset.
if env.global_done:
all_brain_info = env.reset()
else:

texts[brain_name] = action_info.text
values[brain_name] = action_info.value
all_brain_info = env.step(actions, memories, texts, values)
_send_response("step", all_brain_info)
step_queue.put(EnvironmentResponse("step", worker_id, all_brain_info))
elif cmd.name == "external_brains":
_send_response("external_brains", env.external_brains)
elif cmd.name == "reset_parameters":

except KeyboardInterrupt:
print("UnityEnvironment worker: keyboard interrupt")
finally:
step_queue.close()
env.close()

):
super().__init__()
self.env_workers: List[UnityEnvWorker] = []
self.step_queue: Queue = Queue()
self.env_workers.append(self.create_worker(worker_idx, env_factory))
def get_last_steps(self):
return [ew.previous_step for ew in self.env_workers]
self.env_workers.append(
self.create_worker(worker_idx, self.step_queue, env_factory)
)
worker_id: int, env_factory: Callable[[int], BaseUnityEnvironment]
worker_id: int,
step_queue: Queue,
env_factory: Callable[[int], BaseUnityEnvironment],
) -> UnityEnvWorker:
parent_conn, child_conn = Pipe()

child_process = Process(
target=worker, args=(child_conn, pickled_env_factory, worker_id)
target=worker, args=(child_conn, step_queue, pickled_env_factory, worker_id)
def _queue_steps(self) -> None:
for env_worker in self.env_workers:
if not env_worker.waiting:
env_action_info = self._take_step(env_worker.previous_step)
env_worker.previous_all_action_info = env_action_info
env_worker.send("step", env_action_info)
env_worker.waiting = True
for env_worker in self.env_workers:
all_action_info = self._take_step(env_worker.previous_step)
env_worker.previous_all_action_info = all_action_info
env_worker.send("step", all_action_info)
# Queue steps for any workers which aren't in the "waiting" state.
self._queue_steps()
worker_steps: List[EnvironmentResponse] = []
step_workers: Set[int] = set()
# Poll the step queue for completed steps from environment workers until we retrieve
# 1 or more, which we will then return as StepInfos
while len(worker_steps) < 1:
try:
while True:
step = self.step_queue.get_nowait()
self.env_workers[step.worker_id].waiting = False
if step.worker_id not in step_workers:
worker_steps.append(step)
step_workers.add(step.worker_id)
except EmptyQueueException:
pass
with hierarchical_timer("recv"):
step_brain_infos: List[AllBrainInfo] = [
self.env_workers[i].recv().payload for i in range(len(self.env_workers))
]
steps = []
for i in range(len(step_brain_infos)):
env_worker = self.env_workers[i]
step_info = StepInfo(
env_worker.previous_step.current_all_brain_info,
step_brain_infos[i],
env_worker.previous_all_action_info,
)
env_worker.previous_step = step_info
steps.append(step_info)
return steps
step_infos = self._postprocess_steps(worker_steps)
return step_infos
self._broadcast_message("reset", (config, train_mode, custom_reset_parameters))
reset_results = [
self.env_workers[i].recv().payload for i in range(len(self.env_workers))
]
for i in range(len(reset_results)):
env_worker = self.env_workers[i]
env_worker.previous_step = StepInfo(None, reset_results[i], None)
while any([ew.waiting for ew in self.env_workers]):
if not self.step_queue.empty():
step = self.step_queue.get_nowait()
self.env_workers[step.worker_id].waiting = False
# First enqueue reset commands for all workers so that they reset in parallel
for ew in self.env_workers:
ew.send("reset", (config, train_mode, custom_reset_parameters))
# Next (synchronously) collect the reset observations from each worker in sequence
for ew in self.env_workers:
ew.previous_step = StepInfo(None, ew.recv().payload, None)
return list(map(lambda ew: ew.previous_step, self.env_workers))
@property

self.env_workers[0].send("reset_parameters")
return self.env_workers[0].recv().payload
def close(self):
def close(self) -> None:
self.step_queue.close()
self.step_queue.join_thread()
def _broadcast_message(self, name: str, payload=None):
for env in self.env_workers:
env.send(name, payload)
def _postprocess_steps(
self, env_steps: List[EnvironmentResponse]
) -> List[StepInfo]:
step_infos = []
for step in env_steps:
env_worker = self.env_workers[step.worker_id]
new_step = StepInfo(
env_worker.previous_step.current_all_brain_info,
step.payload,
env_worker.previous_all_action_info,
)
step_infos.append(new_step)
env_worker.previous_step = new_step
return step_infos
@timed
def _take_step(self, last_step: StepInfo) -> Dict[str, ActionInfo]:

ml-agents-envs/mlagents/envs/tests/test_subprocess_env_manager.py (63 changed lines)


from unittest.mock import Mock, MagicMock
import unittest
import cloudpickle
from mlagents.envs.subprocess_env_manager import StepInfo
from queue import Empty as EmptyQueue
from mlagents.envs.subprocess_env_manager import (
SubprocessEnvManager,

self.conn = None
self.send = Mock()
self.recv = Mock(return_value=resp)
self.waiting = False
class SubprocessEnvManagerTest(unittest.TestCase):

# Creates two processes
env.create_worker.assert_has_calls(
[mock.call(0, mock_env_factory), mock.call(1, mock_env_factory)]
[
mock.call(0, env.step_queue, mock_env_factory),
mock.call(1, env.step_queue, mock_env_factory),
]
)
self.assertEqual(len(env.env_workers), 2)

return env_mock
mock_parent_connection = Mock()
mock_step_queue = Mock()
step_command = EnvironmentCommand("step", (None, None, None, None))
close_command = EnvironmentCommand("close")
mock_parent_connection.recv.side_effect = [step_command, close_command]

mock_parent_connection, cloudpickle.dumps(mock_global_done_env_factory), 0
mock_parent_connection,
mock_step_queue,
cloudpickle.dumps(mock_global_done_env_factory),
0,
)
# recv called twice to get step and close command

mock_parent_connection.send.assert_called_with(
mock_step_queue.put.assert_called_with(
SubprocessEnvManager.create_worker = lambda em, worker_id, step_queue, env_factory: MockEnvWorker(
worker_id, EnvironmentResponse("reset", worker_id, worker_id)
)
manager = SubprocessEnvManager(mock_env_factory, 1)
params = {"test": "params"}
manager.reset(params, False)

SubprocessEnvManager.create_worker = lambda em, worker_id, env_factory: MockEnvWorker(
SubprocessEnvManager.create_worker = lambda em, worker_id, step_queue, env_factory: MockEnvWorker(
worker_id, EnvironmentResponse("reset", worker_id, worker_id)
)
manager = SubprocessEnvManager(mock_env_factory, 4)

)
assert res == list(map(lambda ew: ew.previous_step, manager.env_workers))
def test_step_takes_steps_for_all_envs(self):
SubprocessEnvManager.create_worker = lambda em, worker_id, env_factory: MockEnvWorker(
def test_step_takes_steps_for_all_non_waiting_envs(self):
SubprocessEnvManager.create_worker = lambda em, worker_id, step_queue, env_factory: MockEnvWorker(
manager = SubprocessEnvManager(mock_env_factory, 2)
manager = SubprocessEnvManager(mock_env_factory, 3)
manager.step_queue = Mock()
manager.step_queue.get_nowait.side_effect = [
EnvironmentResponse("step", 0, 0),
EnvironmentResponse("step", 1, 1),
EmptyQueue(),
]
last_steps = [Mock(), Mock()]
last_steps = [Mock(), Mock(), Mock()]
manager.env_workers[2].previous_step = last_steps[2]
manager.env_workers[2].waiting = True
env.send.assert_called_with("step", step_mock)
env.recv.assert_called()
# Check that the "last steps" are set to the value returned for each step
self.assertEqual(
manager.env_workers[i].previous_step.current_all_brain_info, i
)
self.assertEqual(
manager.env_workers[i].previous_step.previous_all_brain_info,
last_steps[i].current_all_brain_info,
)
assert res == list(map(lambda ew: ew.previous_step, manager.env_workers))
if i < 2:
env.send.assert_called_with("step", step_mock)
manager.step_queue.get_nowait.assert_called()
# Check that the "last steps" are set to the value returned for each step
self.assertEqual(
manager.env_workers[i].previous_step.current_all_brain_info, i
)
self.assertEqual(
manager.env_workers[i].previous_step.previous_all_brain_info,
last_steps[i].current_all_brain_info,
)
assert res == [
manager.env_workers[0].previous_step,
manager.env_workers[1].previous_step,
]