浏览代码
Add environment manager for parallel environments (#2209)
Add environment manager for parallel environments (#2209)
Previously, in v0.8, we added parallel environments via the SubprocessUnityEnvironment, which exposed the same abstraction as UnityEnvironment while actually wrapping many parallel environments via subprocesses. Wrapping many environments with the same interface as a single environment had some downsides, however:
* Ordering needed to be preserved for agents across different envs, complicating the SubprocessEnvironment logic.
* Asynchronous environments, with steps taken out of sync with the trainer, aren't viable with the Environment abstraction.
This PR introduces a new EnvManager abstraction, which exposes a reduced subset of the UnityEnvironment abstraction, and a SubprocessEnvManager implementation, which replaces the SubprocessUnityEnvironment.
/develop-generalizationTraining-TrainerController
GitHub
5 年前
当前提交
b05c9ac1
共有 29 个文件被更改,包括 537 次插入 和 519 次删除
-
4ml-agents-envs/mlagents/envs/__init__.py
-
13ml-agents-envs/mlagents/envs/brain.py
-
3ml-agents-envs/mlagents/envs/environment.py
-
3ml-agents/mlagents/trainers/__init__.py
-
4ml-agents/mlagents/trainers/bc/policy.py
-
4ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py
-
4ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py
-
4ml-agents/mlagents/trainers/components/reward_signals/reward_signal.py
-
4ml-agents/mlagents/trainers/components/reward_signals/reward_signal_factory.py
-
2ml-agents/mlagents/trainers/demo_loader.py
-
6ml-agents/mlagents/trainers/learn.py
-
11ml-agents/mlagents/trainers/models.py
-
4ml-agents/mlagents/trainers/ppo/policy.py
-
13ml-agents/mlagents/trainers/ppo/trainer.py
-
5ml-agents/mlagents/trainers/tests/test_learn.py
-
8ml-agents/mlagents/trainers/tests/test_policy.py
-
45ml-agents/mlagents/trainers/tests/test_ppo.py
-
47ml-agents/mlagents/trainers/tests/test_trainer_controller.py
-
27ml-agents/mlagents/trainers/trainer.py
-
145ml-agents/mlagents/trainers/trainer_controller.py
-
14ml-agents/mlagents/trainers/tf_policy.py
-
38ml-agents-envs/mlagents/envs/env_manager.py
-
10ml-agents-envs/mlagents/envs/policy.py
-
180ml-agents-envs/mlagents/envs/subprocess_env_manager.py
-
110ml-agents-envs/mlagents/envs/tests/test_subprocess_env_manager.py
-
124ml-agents-envs/mlagents/envs/tests/test_subprocess_unity_environment.py
-
224ml-agents-envs/mlagents/envs/subprocess_environment.py
-
0/ml-agents/mlagents/trainers/tf_policy.py
-
0/ml-agents-envs/mlagents/envs/action_info.py
|
|||
from .brain import * |
|||
from .brain import AllBrainInfo, BrainInfo, BrainParameters |
|||
from .action_info import ActionInfo, ActionInfoOutputs |
|||
from .policy import Policy |
|||
from .environment import * |
|||
from .exception import * |
|
|||
from abc import ABC, abstractmethod |
|||
from typing import List, Dict, NamedTuple, Optional |
|||
from mlagents.envs import AllBrainInfo, BrainParameters, Policy, ActionInfo |
|||
|
|||
|
|||
class StepInfo(NamedTuple):
    """Record of a single environment transition handed out by an EnvManager."""

    # Brain infos the step started from; None when there is no earlier step
    # (SubprocessEnvManager.reset builds its records with None here).
    previous_all_brain_info: Optional[AllBrainInfo]
    # Brain infos received from the environment for this step.
    current_all_brain_info: AllBrainInfo
    # Action info keyed by brain name; None when no action was taken to
    # produce this record (again the case for reset records).
    brain_name_to_action_info: Optional[Dict[str, ActionInfo]]
|||
|
|||
|
|||
class EnvManager(ABC):
    """Abstract interface for an object that drives one or more environments.

    Implementations keep a policy per brain name and expose stepping,
    resetting and closing of the environments they manage.
    """

    def __init__(self):
        # Maps a brain name to the policy registered for it via set_policy.
        self.policies: Dict[str, Policy] = {}

    def set_policy(self, brain_name: str, policy: Policy) -> None:
        """Register (or replace) the policy used for ``brain_name``."""
        self.policies[brain_name] = policy

    @abstractmethod
    def step(self) -> List[StepInfo]:
        """Advance the managed environments and return the resulting steps."""
        pass

    @abstractmethod
    def reset(self, config=None, train_mode=True) -> List[StepInfo]:
        """Reset the managed environments and return the resulting steps.

        :param config: Reset configuration forwarded to the environments.
        :param train_mode: Train-mode flag forwarded to the environments.
        """
        pass

    # Declared as an abstract *property* so that the contract matches both
    # ``reset_parameters`` below and the concrete implementation in this
    # change, which exposes ``external_brains`` as a property.
    @property
    @abstractmethod
    def external_brains(self) -> Dict[str, BrainParameters]:
        """Brain parameters keyed by brain name."""
        pass

    @property
    @abstractmethod
    def reset_parameters(self) -> Dict[str, float]:
        """Reset parameters keyed by parameter name."""
        pass

    @abstractmethod
    def close(self):
        """Shut down the managed environments."""
        pass
|
|||
from abc import ABC, abstractmethod |
|||
|
|||
from mlagents.envs import BrainInfo |
|||
from mlagents.envs import ActionInfo |
|||
|
|||
|
|||
class Policy(ABC):
    """Abstract base for anything that can choose actions for a brain."""

    @abstractmethod
    def get_action(self, brain_info: BrainInfo) -> ActionInfo:
        """Return the ActionInfo chosen for the given ``brain_info``."""
|
|||
from typing import * |
|||
import cloudpickle |
|||
|
|||
from mlagents.envs import UnityEnvironment |
|||
from multiprocessing import Process, Pipe |
|||
from multiprocessing.connection import Connection |
|||
from mlagents.envs.base_unity_environment import BaseUnityEnvironment |
|||
from mlagents.envs.env_manager import EnvManager, StepInfo |
|||
from mlagents.envs import AllBrainInfo, BrainParameters, ActionInfo |
|||
|
|||
|
|||
class EnvironmentCommand(NamedTuple):
    """Message sent from the parent process to an environment worker."""

    # Command identifier; the worker loop dispatches on this string.
    name: str
    # Command-specific argument; None for commands that carry no data.
    payload: Any = None
|||
|
|||
|
|||
class EnvironmentResponse(NamedTuple):
    """Message sent from an environment worker back to the parent process."""

    # Name of the command this response answers.
    name: str
    # Id of the worker that produced the response.
    worker_id: int
    # Result of executing the command.
    payload: Any
|||
|
|||
|
|||
class UnityEnvWorker:
    """Parent-side handle for one environment subprocess and its pipe."""

    def __init__(self, process: Process, worker_id: int, conn: Connection):
        self.worker_id = worker_id
        self.process = process
        self.conn = conn
        # Bookkeeping used by the manager between consecutive steps.
        self.previous_all_action_info: Dict[str, ActionInfo] = {}
        self.previous_step: StepInfo = StepInfo(None, {}, None)

    def send(self, name: str, payload=None):
        """Send command ``name`` with ``payload`` to the subprocess.

        A broken or closed pipe is surfaced as KeyboardInterrupt.
        """
        try:
            self.conn.send(EnvironmentCommand(name, payload))
        except (BrokenPipeError, EOFError):
            raise KeyboardInterrupt

    def recv(self) -> EnvironmentResponse:
        """Receive the next response from the subprocess.

        A broken or closed pipe is surfaced as KeyboardInterrupt.
        """
        try:
            return self.conn.recv()
        except (BrokenPipeError, EOFError):
            raise KeyboardInterrupt

    def close(self):
        """Ask the subprocess to stop, then wait for it to exit."""
        shutdown = EnvironmentCommand("close")
        try:
            self.conn.send(shutdown)
        except (BrokenPipeError, EOFError):
            # The pipe is already gone; there is nobody left to notify.
            pass
        self.process.join()
|||
|
|||
|
|||
def worker(parent_conn: Connection, pickled_env_factory: str, worker_id: int):
    """Subprocess entry point: build an environment and serve commands.

    :param parent_conn: Connection used to receive commands and send responses.
    :param pickled_env_factory: cloudpickle-serialized factory that is called
        with ``worker_id`` to build the environment.
    :param worker_id: Id passed to the factory and echoed in every response.
    """
    build_env: Callable[[int], UnityEnvironment] = cloudpickle.loads(
        pickled_env_factory
    )
    env = build_env(worker_id)

    def _send_response(cmd_name, payload):
        parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))

    def _advance(all_action_info):
        # A finished environment is reset instead of stepped; the supplied
        # action info is not consulted in that case.
        if env.global_done:
            return env.reset()
        actions, memories, texts, values = {}, {}, {}, {}
        for brain_name, action_info in all_action_info.items():
            actions[brain_name] = action_info.action
            memories[brain_name] = action_info.memory
            texts[brain_name] = action_info.text
            values[brain_name] = action_info.value
        return env.step(actions, memories, texts, values)

    try:
        while True:
            cmd: EnvironmentCommand = parent_conn.recv()
            request = cmd.name
            if request == "close":
                break
            if request == "step":
                _send_response("step", _advance(cmd.payload))
            elif request == "external_brains":
                _send_response("external_brains", env.external_brains)
            elif request == "reset_parameters":
                _send_response("reset_parameters", env.reset_parameters)
            elif request == "reset":
                config, train_mode = cmd.payload[0], cmd.payload[1]
                _send_response("reset", env.reset(config, train_mode))
            elif request == "global_done":
                _send_response("global_done", env.global_done)
            # Any other command name is ignored and the loop keeps waiting.
    except KeyboardInterrupt:
        print("UnityEnvironment worker: keyboard interrupt")
    finally:
        env.close()
|||
|
|||
|
|||
class SubprocessEnvManager(EnvManager):
    """EnvManager that runs each environment in its own subprocess."""

    def __init__(
        self, env_factory: Callable[[int], BaseUnityEnvironment], n_env: int = 1
    ):
        """Spawn ``n_env`` workers, each building its env via ``env_factory``."""
        super().__init__()
        self.env_workers: List[UnityEnvWorker] = []
        for worker_idx in range(n_env):
            new_worker = self.create_worker(worker_idx, env_factory)
            self.env_workers.append(new_worker)

    def get_last_steps(self):
        """Return the most recent StepInfo stored on every worker."""
        return list(map(lambda ew: ew.previous_step, self.env_workers))

    @staticmethod
    def create_worker(
        worker_id: int, env_factory: Callable[[int], BaseUnityEnvironment]
    ) -> UnityEnvWorker:
        """Start a subprocess running ``worker`` and return its handle."""
        parent_conn, child_conn = Pipe()

        # Need to use cloudpickle for the env factory function since function objects aren't picklable
        # on Windows as of Python 3.6.
        serialized_factory = cloudpickle.dumps(env_factory)
        proc = Process(target=worker, args=(child_conn, serialized_factory, worker_id))
        proc.start()
        return UnityEnvWorker(proc, worker_id, parent_conn)

    def step(self) -> List[StepInfo]:
        """Step every worker once and return one StepInfo per worker."""
        # Dispatch an action to every worker before waiting on any of them.
        for env_worker in self.env_workers:
            chosen_actions = self._take_step(env_worker.previous_step)
            env_worker.previous_all_action_info = chosen_actions
            env_worker.send("step", chosen_actions)

        # Collect all responses first, in worker order.
        step_brain_infos: List[AllBrainInfo] = [
            env_worker.recv().payload for env_worker in self.env_workers
        ]

        steps = []
        for env_worker, brain_infos in zip(self.env_workers, step_brain_infos):
            step_info = StepInfo(
                env_worker.previous_step.current_all_brain_info,
                brain_infos,
                env_worker.previous_all_action_info,
            )
            env_worker.previous_step = step_info
            steps.append(step_info)
        return steps

    def reset(self, config=None, train_mode=True) -> List[StepInfo]:
        """Reset every worker and return the fresh StepInfo of each one."""
        self._broadcast_message("reset", (config, train_mode))
        reset_results = [env_worker.recv().payload for env_worker in self.env_workers]
        for env_worker, brain_infos in zip(self.env_workers, reset_results):
            env_worker.previous_step = StepInfo(None, brain_infos, None)
        return [env_worker.previous_step for env_worker in self.env_workers]

    @property
    def external_brains(self) -> Dict[str, BrainParameters]:
        """External brains as reported by the first worker."""
        first_worker = self.env_workers[0]
        first_worker.send("external_brains")
        return first_worker.recv().payload

    @property
    def reset_parameters(self) -> Dict[str, float]:
        """Reset parameters as reported by the first worker."""
        first_worker = self.env_workers[0]
        first_worker.send("reset_parameters")
        return first_worker.recv().payload

    def close(self):
        """Close every worker in order."""
        for env_worker in self.env_workers:
            env_worker.close()

    def _broadcast_message(self, name: str, payload=None):
        """Send the same command to every worker."""
        for env_worker in self.env_workers:
            env_worker.send(name, payload)

    def _take_step(self, last_step: StepInfo) -> Dict[str, ActionInfo]:
        """Ask each brain's registered policy for an action on ``last_step``."""
        all_action_info: Dict[str, ActionInfo] = {}
        for brain_name, brain_info in last_step.current_all_brain_info.items():
            policy = self.policies[brain_name]
            all_action_info[brain_name] = policy.get_action(brain_info)
        return all_action_info
|
|||
import unittest.mock as mock |
|||
from unittest.mock import Mock, MagicMock |
|||
import unittest |
|||
import cloudpickle |
|||
from mlagents.envs.subprocess_env_manager import StepInfo |
|||
|
|||
from mlagents.envs.subprocess_env_manager import ( |
|||
SubprocessEnvManager, |
|||
EnvironmentResponse, |
|||
EnvironmentCommand, |
|||
worker, |
|||
) |
|||
from mlagents.envs.base_unity_environment import BaseUnityEnvironment |
|||
|
|||
|
|||
def mock_env_factory(worker_id: int):
    """Environment factory stub: ignores ``worker_id`` and returns an autospec of BaseUnityEnvironment."""
    autospecced_env = mock.create_autospec(spec=BaseUnityEnvironment)
    return autospecced_env
|||
|
|||
|
|||
class MockEnvWorker:
    """Stand-in for an env worker whose ``recv`` always returns ``resp``."""

    def __init__(self, worker_id, resp=None):
        # No real subprocess or pipe backs this worker.
        self.process = self.conn = None
        self.worker_id = worker_id
        self.recv = Mock(return_value=resp)
        self.send = Mock()
|||
|
|||
|
|||
class SubprocessEnvManagerTest(unittest.TestCase):
    """Tests for SubprocessEnvManager and the subprocess ``worker`` loop.

    ``SubprocessEnvManager.create_worker`` is swapped out with
    ``mock.patch.object`` so that no real subprocess is started and the
    original static method is restored when the patch exits. Each test
    therefore sets up its own replacement instead of relying on one left
    behind by a test that happened to run earlier.
    """

    @staticmethod
    def _patch_create_worker(response_name):
        """Return a patcher that makes create_worker build MockEnvWorkers.

        Every mocked worker answers ``recv`` with an EnvironmentResponse named
        ``response_name`` whose worker id and payload are both the worker id.
        """
        # A plain function stored on the class is bound on instance access,
        # hence the leading ``em`` (the manager instance) parameter.
        return mock.patch.object(
            SubprocessEnvManager,
            "create_worker",
            new=lambda em, worker_id, env_factory: MockEnvWorker(
                worker_id, EnvironmentResponse(response_name, worker_id, worker_id)
            ),
        )

    def test_environments_are_created(self):
        with mock.patch.object(SubprocessEnvManager, "create_worker") as create_worker:
            env = SubprocessEnvManager(mock_env_factory, 2)
        # Creates two processes
        create_worker.assert_has_calls(
            [mock.call(0, mock_env_factory), mock.call(1, mock_env_factory)]
        )
        self.assertEqual(len(env.env_workers), 2)

    def test_worker_step_resets_on_global_done(self):
        env_mock = Mock()
        env_mock.reset = Mock(return_value="reset_data")
        env_mock.global_done = True

        def mock_global_done_env_factory(worker_id: int):
            return env_mock

        mock_parent_connection = Mock()
        step_command = EnvironmentCommand("step", (None, None, None, None))
        close_command = EnvironmentCommand("close")
        mock_parent_connection.recv.side_effect = [step_command, close_command]
        mock_parent_connection.send = Mock()

        worker(
            mock_parent_connection, cloudpickle.dumps(mock_global_done_env_factory), 0
        )

        # recv called twice to get step and close command
        self.assertEqual(mock_parent_connection.recv.call_count, 2)

        # worker returns the data from the reset
        mock_parent_connection.send.assert_called_with(
            EnvironmentResponse("step", 0, "reset_data")
        )

    def test_reset_passes_reset_params(self):
        # Patch explicitly: without it this test would spawn a real subprocess
        # unless another test had already replaced create_worker.
        with self._patch_create_worker("reset"):
            manager = SubprocessEnvManager(mock_env_factory, 1)
        params = {"test": "params"}
        manager.reset(params, False)
        manager.env_workers[0].send.assert_called_with("reset", (params, False))

    def test_reset_collects_results_from_all_envs(self):
        with self._patch_create_worker("reset"):
            manager = SubprocessEnvManager(mock_env_factory, 4)

        params = {"test": "params"}
        res = manager.reset(params)
        for i, env in enumerate(manager.env_workers):
            env.send.assert_called_with("reset", (params, True))
            env.recv.assert_called()
            # Check that the "last steps" are set to the value returned for each step
            self.assertEqual(
                manager.env_workers[i].previous_step.current_all_brain_info, i
            )
        self.assertEqual(
            res, list(map(lambda ew: ew.previous_step, manager.env_workers))
        )

    def test_step_takes_steps_for_all_envs(self):
        with self._patch_create_worker("step"):
            manager = SubprocessEnvManager(mock_env_factory, 2)
        step_mock = Mock()
        last_steps = [Mock(), Mock()]
        manager.env_workers[0].previous_step = last_steps[0]
        manager.env_workers[1].previous_step = last_steps[1]
        manager._take_step = Mock(return_value=step_mock)
        res = manager.step()
        for i, env in enumerate(manager.env_workers):
            env.send.assert_called_with("step", step_mock)
            env.recv.assert_called()
            # Check that the "last steps" are set to the value returned for each step
            self.assertEqual(
                manager.env_workers[i].previous_step.current_all_brain_info, i
            )
            self.assertEqual(
                manager.env_workers[i].previous_step.previous_all_brain_info,
                last_steps[i].current_all_brain_info,
            )
        self.assertEqual(
            res, list(map(lambda ew: ew.previous_step, manager.env_workers))
        )
|
|||
import unittest.mock as mock |
|||
from unittest.mock import Mock, MagicMock |
|||
import unittest |
|||
|
|||
from mlagents.envs.subprocess_environment import * |
|||
from mlagents.envs import UnityEnvironmentException, BrainInfo |
|||
from mlagents.envs.base_unity_environment import BaseUnityEnvironment |
|||
|
|||
|
|||
def mock_env_factory(worker_id: int):
    """Environment factory stub: ignores ``worker_id`` and returns an autospec of BaseUnityEnvironment."""
    autospecced_env = mock.create_autospec(spec=BaseUnityEnvironment)
    return autospecced_env
|||
|
|||
|
|||
class MockEnvWorker:
    """Stand-in for an env worker with MagicMock ``send`` and ``recv``."""

    def __init__(self, worker_id):
        # No real subprocess or pipe backs this worker.
        self.process = self.conn = None
        self.worker_id = worker_id
        self.recv = MagicMock()
        self.send = MagicMock()
|||
|
|||
|
|||
class SubprocessEnvironmentTest(unittest.TestCase):
    """Tests for SubprocessUnityEnvironment and the subprocess ``worker`` loop."""

    def test_environments_are_created(self):
        # Scope the replacement of create_worker to this test so the real
        # static method is restored afterwards instead of staying mocked on
        # the class for every later test.
        with mock.patch.object(
            SubprocessUnityEnvironment, "create_worker"
        ) as create_worker:
            env = SubprocessUnityEnvironment(mock_env_factory, 2)
        # Creates two processes
        self.assertEqual(
            create_worker.call_args_list,
            [mock.call(0, mock_env_factory), mock.call(1, mock_env_factory)],
        )
        self.assertEqual(len(env.envs), 2)

    def test_step_async_fails_when_waiting(self):
        env = SubprocessUnityEnvironment(mock_env_factory, 0)
        env.waiting = True
        with self.assertRaises(UnityEnvironmentException):
            env.step_async(vector_action=[])

    @staticmethod
    def test_step_async_splits_input_by_agent_count():
        env = SubprocessUnityEnvironment(mock_env_factory, 0)
        env.env_agent_counts = {"MockBrain": [1, 3, 5]}
        env.envs = [MockEnvWorker(0), MockEnvWorker(1), MockEnvWorker(2)]
        env_0_actions = [[1.0, 2.0]]
        env_1_actions = [[3.0, 4.0]] * 3
        env_2_actions = [[5.0, 6.0]] * 5
        vector_action = {"MockBrain": env_0_actions + env_1_actions + env_2_actions}
        env.step_async(vector_action=vector_action)
        env.envs[0].send.assert_called_with(
            "step", ({"MockBrain": env_0_actions}, {}, {}, {})
        )
        env.envs[1].send.assert_called_with(
            "step", ({"MockBrain": env_1_actions}, {}, {}, {})
        )
        env.envs[2].send.assert_called_with(
            "step", ({"MockBrain": env_2_actions}, {}, {}, {})
        )

    def test_step_async_sets_waiting(self):
        env = SubprocessUnityEnvironment(mock_env_factory, 0)
        env.step_async(vector_action=[])
        self.assertTrue(env.waiting)

    def test_step_await_fails_if_not_waiting(self):
        env = SubprocessUnityEnvironment(mock_env_factory, 0)
        with self.assertRaises(UnityEnvironmentException):
            env.step_await()

    def test_step_await_combines_brain_info(self):
        all_brain_info_env0 = {
            "MockBrain": BrainInfo(
                [], [[1.0, 2.0], [1.0, 2.0]], [], agents=[1, 2], memory=np.zeros((0, 0))
            )
        }
        all_brain_info_env1 = {
            "MockBrain": BrainInfo(
                [], [[3.0, 4.0]], [], agents=[3], memory=np.zeros((0, 0))
            )
        }
        env_worker_0 = MockEnvWorker(0)
        env_worker_0.recv.return_value = EnvironmentResponse(
            "step", 0, all_brain_info_env0
        )
        env_worker_1 = MockEnvWorker(1)
        env_worker_1.recv.return_value = EnvironmentResponse(
            "step", 1, all_brain_info_env1
        )
        env = SubprocessUnityEnvironment(mock_env_factory, 0)
        env.envs = [env_worker_0, env_worker_1]
        env.waiting = True
        combined_braininfo = env.step_await()["MockBrain"]
        self.assertEqual(
            combined_braininfo.vector_observations.tolist(),
            [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0]],
        )
        # Agent ids are expected to come back prefixed with the worker id.
        self.assertEqual(combined_braininfo.agents, ["0-1", "0-2", "1-3"])

    def test_step_resets_on_global_done(self):
        env_mock = Mock()
        env_mock.reset = Mock(return_value="reset_data")
        env_mock.global_done = True

        def mock_global_done_env_factory(worker_id: int):
            return env_mock

        mock_parent_connection = Mock()
        step_command = EnvironmentCommand("step", (None, None, None, None))
        close_command = EnvironmentCommand("close")
        mock_parent_connection.recv = Mock()
        mock_parent_connection.recv.side_effect = [step_command, close_command]
        mock_parent_connection.send = Mock()

        worker(
            mock_parent_connection, cloudpickle.dumps(mock_global_done_env_factory), 0
        )

        # recv called twice to get step and close command
        self.assertEqual(mock_parent_connection.recv.call_count, 2)

        # worker returns the data from the reset
        mock_parent_connection.send.assert_called_with(
            EnvironmentResponse("step", 0, "reset_data")
        )
|
|||
from typing import * |
|||
import copy |
|||
import numpy as np |
|||
import cloudpickle |
|||
|
|||
from mlagents.envs import UnityEnvironment |
|||
from multiprocessing import Process, Pipe |
|||
from multiprocessing.connection import Connection |
|||
from mlagents.envs.base_unity_environment import BaseUnityEnvironment |
|||
from mlagents.envs import AllBrainInfo, UnityEnvironmentException |
|||
|
|||
|
|||
class EnvironmentCommand(NamedTuple):
    """Message sent from the parent process to an environment worker."""

    # Command identifier; the worker loop dispatches on this string.
    name: str
    # Command-specific argument; None for commands that carry no data.
    payload: Any = None
|||
|
|||
|
|||
class EnvironmentResponse(NamedTuple):
    """Message sent from an environment worker back to the parent process."""

    # Name of the command this response answers.
    name: str
    # Id of the worker that produced the response.
    worker_id: int
    # Result of executing the command.
    payload: Any
|||
|
|||
|
|||
class UnityEnvWorker(NamedTuple):
    """Immutable parent-side handle for one environment subprocess."""

    process: Process
    worker_id: int
    conn: Connection

    def send(self, name: str, payload=None):
        """Send command ``name`` with ``payload`` to the subprocess.

        A broken or closed pipe is surfaced as KeyboardInterrupt.
        """
        try:
            self.conn.send(EnvironmentCommand(name, payload))
        except (BrokenPipeError, EOFError):
            raise KeyboardInterrupt

    def recv(self) -> EnvironmentResponse:
        """Receive the next response from the subprocess.

        A broken or closed pipe is surfaced as KeyboardInterrupt.
        """
        try:
            return self.conn.recv()
        except (BrokenPipeError, EOFError):
            raise KeyboardInterrupt

    def close(self):
        """Ask the subprocess to stop, then wait for it to exit."""
        shutdown = EnvironmentCommand("close")
        try:
            self.conn.send(shutdown)
        except (BrokenPipeError, EOFError):
            # The pipe is already gone; there is nobody left to notify.
            pass
        self.process.join()
|||
|
|||
|
|||
def worker(parent_conn: Connection, pickled_env_factory: str, worker_id: int):
    """Subprocess entry point: build an environment and serve commands.

    :param parent_conn: Connection used to receive commands and send responses.
    :param pickled_env_factory: cloudpickle-serialized factory that is called
        with ``worker_id`` to build the environment.
    :param worker_id: Id passed to the factory and echoed in every response.
    """
    build_env: Callable[[int], UnityEnvironment] = cloudpickle.loads(
        pickled_env_factory
    )
    env = build_env(worker_id)

    def _send_response(cmd_name, payload):
        parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))

    def _advance(step_payload):
        # The payload is unpacked before the done check, as a 4-tuple.
        vector_action, memory, text_action, value = step_payload
        if env.global_done:
            # A finished environment is reset instead of stepped.
            return env.reset()
        return env.step(vector_action, memory, text_action, value)

    try:
        while True:
            cmd: EnvironmentCommand = parent_conn.recv()
            request = cmd.name
            if request == "close":
                break
            if request == "step":
                _send_response("step", _advance(cmd.payload))
            elif request == "external_brains":
                _send_response("external_brains", env.external_brains)
            elif request == "reset_parameters":
                _send_response("reset_parameters", env.reset_parameters)
            elif request == "reset":
                config, train_mode = cmd.payload[0], cmd.payload[1]
                _send_response("reset", env.reset(config, train_mode))
            elif request == "global_done":
                _send_response("global_done", env.global_done)
            # Any other command name is ignored and the loop keeps waiting.
    except KeyboardInterrupt:
        print("UnityEnvironment worker: keyboard interrupt")
    finally:
        env.close()
|||
|
|||
|
|||
class SubprocessUnityEnvironment(BaseUnityEnvironment): |
|||
def __init__( |
|||
self, env_factory: Callable[[int], BaseUnityEnvironment], n_env: int = 1 |
|||
): |
|||
self.envs: List[UnityEnvWorker] = [] |
|||
self.env_agent_counts: Dict[str, List[int]] = {} |
|||
self.waiting = False |
|||
for worker_id in range(n_env): |
|||
self.envs.append(self.create_worker(worker_id, env_factory)) |
|||
|
|||
@staticmethod |
|||
def create_worker( |
|||
worker_id: int, env_factory: Callable[[int], BaseUnityEnvironment] |
|||
) -> UnityEnvWorker: |
|||
parent_conn, child_conn = Pipe() |
|||
|
|||
# Need to use cloudpickle for the env factory function since function objects aren't picklable |
|||
# on Windows as of Python 3.6. |
|||
pickled_env_factory = cloudpickle.dumps(env_factory) |
|||
child_process = Process( |
|||
target=worker, args=(child_conn, pickled_env_factory, worker_id) |
|||
) |
|||
child_process.start() |
|||
return UnityEnvWorker(child_process, worker_id, parent_conn) |
|||
|
|||
def step_async( |
|||