浏览代码

Renamed "StepInfo" to "EnvironmentStep"

This change was requested for clarity during review of the async EnvManager
PR. It is a pure rename of the StepInfo class to EnvironmentStep, with no behavior changes.
/develop-gpu-test
Jonathan Harper 5 年前
当前提交
2f083c8a
共有 5 个文件被更改,包括 23 次插入 和 23 次删除
  1. 6
      ml-agents-envs/mlagents/envs/env_manager.py
  2. 14
      ml-agents-envs/mlagents/envs/simple_env_manager.py
  3. 16
      ml-agents-envs/mlagents/envs/subprocess_env_manager.py
  4. 6
      ml-agents/mlagents/trainers/tests/test_trainer_controller.py
  5. 4
      ml-agents/mlagents/trainers/trainer_controller.py

6
ml-agents-envs/mlagents/envs/env_manager.py


from mlagents.envs import AllBrainInfo, BrainParameters, Policy, ActionInfo
class StepInfo(NamedTuple):
class EnvironmentStep(NamedTuple):
previous_all_brain_info: Optional[AllBrainInfo]
current_all_brain_info: AllBrainInfo
brain_name_to_action_info: Optional[Dict[str, ActionInfo]]

self.policies[brain_name] = policy
@abstractmethod
def step(self) -> List[StepInfo]:
def step(self) -> List[EnvironmentStep]:
def reset(self, config=None, train_mode=True) -> List[StepInfo]:
def reset(self, config=None, train_mode=True) -> List[EnvironmentStep]:
pass
@property

14
ml-agents-envs/mlagents/envs/simple_env_manager.py


from typing import Any, Dict, List
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.env_manager import EnvManager, StepInfo
from mlagents.envs.env_manager import EnvManager, EnvironmentStep
from mlagents.envs.timers import timed
from mlagents.envs import ActionInfo, BrainParameters

def __init__(self, env: BaseUnityEnvironment):
super().__init__()
self.env = env
self.previous_step: StepInfo = StepInfo(None, {}, None)
self.previous_step: EnvironmentStep = EnvironmentStep(None, {}, None)
def step(self) -> List[StepInfo]:
def step(self) -> List[EnvironmentStep]:
all_action_info = self._take_step(self.previous_step)
self.previous_all_action_info = all_action_info

all_brain_info = self.env.step(actions, memories, texts, values)
step_brain_info = all_brain_info
step_info = StepInfo(
step_info = EnvironmentStep(
self.previous_step.current_all_brain_info,
step_brain_info,
self.previous_all_action_info,

config: Dict[str, float] = None,
train_mode: bool = True,
custom_reset_parameters: Any = None,
) -> List[StepInfo]: # type: ignore
) -> List[EnvironmentStep]: # type: ignore
self.previous_step = StepInfo(None, all_brain_info, None)
self.previous_step = EnvironmentStep(None, all_brain_info, None)
return [self.previous_step]
@property

self.env.close()
@timed
def _take_step(self, last_step: StepInfo) -> Dict[str, ActionInfo]:
def _take_step(self, last_step: EnvironmentStep) -> Dict[str, ActionInfo]:
all_action_info: Dict[str, ActionInfo] = {}
for brain_name, brain_info in last_step.current_all_brain_info.items():
all_action_info[brain_name] = self.policies[brain_name].get_action(

16
ml-agents-envs/mlagents/envs/subprocess_env_manager.py


from multiprocessing.connection import Connection
from queue import Empty as EmptyQueueException
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.env_manager import EnvManager, StepInfo
from mlagents.envs.env_manager import EnvManager, EnvironmentStep
from mlagents.envs.timers import (
TimerNode,
timed,

self.process = process
self.worker_id = worker_id
self.conn = conn
self.previous_step: StepInfo = StepInfo(None, {}, None)
self.previous_step: EnvironmentStep = EnvironmentStep(None, {}, None)
self.previous_all_action_info: Dict[str, ActionInfo] = {}
self.waiting = False

env_worker.send("step", env_action_info)
env_worker.waiting = True
def step(self) -> List[StepInfo]:
def step(self) -> List[EnvironmentStep]:
# Queue steps for any workers which aren't in the "waiting" state.
self._queue_steps()

def reset(
self, config=None, train_mode=True, custom_reset_parameters=None
) -> List[StepInfo]:
) -> List[EnvironmentStep]:
while any([ew.waiting for ew in self.env_workers]):
if not self.step_queue.empty():
step = self.step_queue.get_nowait()

ew.send("reset", (config, train_mode, custom_reset_parameters))
# Next (synchronously) collect the reset observations from each worker in sequence
for ew in self.env_workers:
ew.previous_step = StepInfo(None, ew.recv().payload, None)
ew.previous_step = EnvironmentStep(None, ew.recv().payload, None)
return list(map(lambda ew: ew.previous_step, self.env_workers))
@property

def _postprocess_steps(
self, env_steps: List[EnvironmentResponse]
) -> List[StepInfo]:
) -> List[EnvironmentStep]:
new_step = StepInfo(
new_step = EnvironmentStep(
env_worker.previous_step.current_all_brain_info,
payload.all_brain_info,
env_worker.previous_all_action_info,

return step_infos
@timed
def _take_step(self, last_step: StepInfo) -> Dict[str, ActionInfo]:
def _take_step(self, last_step: EnvironmentStep) -> Dict[str, ActionInfo]:
all_action_info: Dict[str, ActionInfo] = {}
for brain_name, brain_info in last_step.current_all_brain_info.items():
if brain_name in self.policies:

6
ml-agents/mlagents/trainers/tests/test_trainer_controller.py


import pytest
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.envs.subprocess_env_manager import StepInfo
from mlagents.envs.subprocess_env_manager import EnvironmentStep
from mlagents.envs.sampler_class import SamplerManager

def test_take_step_adds_experiences_to_trainer_and_trains():
tc, trainer_mock = trainer_controller_with_take_step_mocks()
old_step_info = StepInfo(Mock(), Mock(), MagicMock())
new_step_info = StepInfo(Mock(), Mock(), MagicMock())
old_step_info = EnvironmentStep(Mock(), Mock(), MagicMock())
new_step_info = EnvironmentStep(Mock(), Mock(), MagicMock())
trainer_mock.is_ready_update = MagicMock(return_value=True)
env_mock = MagicMock()

4
ml-agents/mlagents/trainers/trainer_controller.py


import tensorflow as tf
from time import time
from mlagents.envs.env_manager import StepInfo
from mlagents.envs.env_manager import EnvironmentStep
from mlagents.envs.env_manager import EnvManager
from mlagents.envs.exception import (
UnityEnvironmentException,

"permissions are set correctly.".format(model_path)
)
def _reset_env(self, env: EnvManager) -> List[StepInfo]:
def _reset_env(self, env: EnvManager) -> List[EnvironmentStep]:
"""Resets the environment.
Returns:

正在加载...
取消
保存