|
|
|
|
|
|
from multiprocessing.connection import Connection |
|
|
|
from queue import Empty as EmptyQueueException |
|
|
|
from mlagents.envs.base_unity_environment import BaseUnityEnvironment |
|
|
|
from mlagents.envs.env_manager import EnvManager, StepInfo |
|
|
|
from mlagents.envs.env_manager import EnvManager, EnvironmentStep |
|
|
|
from mlagents.envs.timers import ( |
|
|
|
TimerNode, |
|
|
|
timed, |
|
|
|
|
|
|
self.process = process |
|
|
|
self.worker_id = worker_id |
|
|
|
self.conn = conn |
|
|
|
self.previous_step: StepInfo = StepInfo(None, {}, None) |
|
|
|
self.previous_step: EnvironmentStep = EnvironmentStep(None, {}, None) |
|
|
|
self.previous_all_action_info: Dict[str, ActionInfo] = {} |
|
|
|
self.waiting = False |
|
|
|
|
|
|
|
|
|
|
env_worker.send("step", env_action_info) |
|
|
|
env_worker.waiting = True |
|
|
|
|
|
|
|
def step(self) -> List[StepInfo]: |
|
|
|
def step(self) -> List[EnvironmentStep]: |
|
|
|
# Queue steps for any workers which aren't in the "waiting" state. |
|
|
|
self._queue_steps() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset( |
|
|
|
self, config=None, train_mode=True, custom_reset_parameters=None |
|
|
|
) -> List[StepInfo]: |
|
|
|
) -> List[EnvironmentStep]: |
|
|
|
while any([ew.waiting for ew in self.env_workers]): |
|
|
|
if not self.step_queue.empty(): |
|
|
|
step = self.step_queue.get_nowait() |
|
|
|
|
|
|
ew.send("reset", (config, train_mode, custom_reset_parameters)) |
|
|
|
# Next (synchronously) collect the reset observations from each worker in sequence |
|
|
|
for ew in self.env_workers: |
|
|
|
ew.previous_step = StepInfo(None, ew.recv().payload, None) |
|
|
|
ew.previous_step = EnvironmentStep(None, ew.recv().payload, None) |
|
|
|
return list(map(lambda ew: ew.previous_step, self.env_workers)) |
|
|
|
|
|
|
|
@property |
|
|
|
|
|
|
|
|
|
|
def _postprocess_steps( |
|
|
|
self, env_steps: List[EnvironmentResponse] |
|
|
|
) -> List[StepInfo]: |
|
|
|
) -> List[EnvironmentStep]: |
|
|
|
new_step = StepInfo( |
|
|
|
new_step = EnvironmentStep( |
|
|
|
env_worker.previous_step.current_all_brain_info, |
|
|
|
payload.all_brain_info, |
|
|
|
env_worker.previous_all_action_info, |
|
|
|
|
|
|
return step_infos |
|
|
|
|
|
|
|
@timed |
|
|
|
def _take_step(self, last_step: StepInfo) -> Dict[str, ActionInfo]: |
|
|
|
def _take_step(self, last_step: EnvironmentStep) -> Dict[str, ActionInfo]: |
|
|
|
all_action_info: Dict[str, ActionInfo] = {} |
|
|
|
for brain_name, brain_info in last_step.current_all_brain_info.items(): |
|
|
|
if brain_name in self.policies: |
|
|
|