from typing import Dict, List

from mlagents_envs.base_env import BaseEnv, AgentGroup
from mlagents.trainers.env_manager import EnvManager, EnvironmentStep, AllStepResult
from mlagents_envs.timers import timed
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.brain import BrainParameters
from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel
from mlagents.trainers.brain_conversion_utils import group_spec_to_brain_parameters


class SimpleEnvManager(EnvManager):
    """
    Simple implementation of the EnvManager interface that only handles one BaseEnv at a time.
    This is generally only useful for testing; see SubprocessEnvManager for a production-quality implementation.
    """

    def __init__(self, env: BaseEnv, float_prop_channel: FloatPropertiesChannel):
        super().__init__()
        self.shared_float_properties = float_prop_channel
        self.env = env
        # Empty placeholder for worker 0 until reset() stores the first real step.
        self.previous_step: EnvironmentStep = EnvironmentStep.empty(0)
        self.previous_all_action_info: Dict[str, ActionInfo] = {}

    def step(self) -> List[EnvironmentStep]:
        # Choose actions from the results of the previous step (or of reset()).
        all_action_info = self._take_step(self.previous_step)
        self.previous_all_action_info = all_action_info

        # Send each group's actions to the environment, then advance it one step.
        for brain_name, action_info in all_action_info.items():
            self.env.set_actions(brain_name, action_info.action)
        self.env.step()
        all_step_result = self._generate_all_results()

        # Pair the new results with the actions that produced them (worker_id 0).
        step_info = EnvironmentStep(all_step_result, 0, self.previous_all_action_info)
        self.previous_step = step_info
        return [step_info]

    def reset(
        self, config: Dict[AgentGroup, float] = None
    ) -> List[EnvironmentStep]:  # type: ignore
        # Forward any reset parameters through the float-properties side channel.
        if config is not None:
            for k, v in config.items():
                self.shared_float_properties.set_property(k, v)
        self.env.reset()
        all_step_result = self._generate_all_results()
        # No actions have been taken yet, so the action info is empty.
        self.previous_step = EnvironmentStep(all_step_result, 0, {})
        return [self.previous_step]

    @property
    def external_brains(self) -> Dict[AgentGroup, BrainParameters]:
        result = {}
        for brain_name in self.env.get_agent_groups():
            result[brain_name] = group_spec_to_brain_parameters(
                brain_name, self.env.get_agent_group_spec(brain_name)
            )
        return result

    @property
    def get_properties(self) -> Dict[AgentGroup, float]:
        return self.shared_float_properties.get_property_dict_copy()

    def close(self):
        self.env.close()

    @timed
    def _take_step(self, last_step: EnvironmentStep) -> Dict[AgentGroup, ActionInfo]:
        all_action_info: Dict[str, ActionInfo] = {}
        for brain_name, step_info in last_step.current_all_step_result.items():
            all_action_info[brain_name] = self.policies[brain_name].get_action(
                step_info,
                0,  # As there is only one worker, we assign the worker_id to 0.
            )
        return all_action_info

    def _generate_all_results(self) -> AllStepResult:
        all_step_result: AllStepResult = {}
        for brain_name in self.env.get_agent_groups():
            all_step_result[brain_name] = self.env.get_step_result(brain_name)
        return all_step_result
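

# Illustrative sketch, not part of ml-agents: a dependency-free model of the
# step cycle implemented by SimpleEnvManager above. _ToyEnv, _ToyPolicy and
# _sketch_step_cycle are hypothetical stand-ins for a BaseEnv, the per-group
# policies in self.policies, and a reset() followed by repeated step() calls.
# Plain dicts and ints replace AllStepResult and ActionInfo. The sketch shows
# the one-step lag: each step's actions are computed from the previous step's
# results, and each recorded step pairs the new results with the actions that
# produced them.


class _ToyEnv:
    """Hypothetical one-group env; its observation is the running sum of actions."""

    def __init__(self):
        self.total = 0
        self.pending = {}

    def get_agent_groups(self):
        return ["group_a"]

    def set_actions(self, group, action):
        self.pending[group] = action

    def step(self):
        # Apply the pending action (0 if none was set) to the running total.
        self.total += self.pending.pop("group_a", 0)

    def reset(self):
        self.total = 0
        self.pending.clear()

    def get_step_result(self, group):
        return self.total


class _ToyPolicy:
    """Hypothetical policy that always acts with (observation + 1)."""

    def get_action(self, step_result, worker_id):
        return step_result + 1


def _sketch_step_cycle(num_steps):
    """Run reset() then num_steps step() calls; return (results, actions) per step."""
    env = _ToyEnv()
    policies = {"group_a": _ToyPolicy()}
    # Mirrors reset(): collect results for every group; no actions taken yet.
    env.reset()
    previous_results = {g: env.get_step_result(g) for g in env.get_agent_groups()}
    history = []
    for _ in range(num_steps):
        # Mirrors _take_step(): act on the previous results with worker_id 0.
        actions = {
            g: policies[g].get_action(r, 0) for g, r in previous_results.items()
        }
        # Mirrors step(): send actions, advance the environment, gather results.
        for g, a in actions.items():
            env.set_actions(g, a)
        env.step()
        results = {g: env.get_step_result(g) for g in env.get_agent_groups()}
        history.append((results, actions))
        previous_results = results
    return history


# Example: _sketch_step_cycle(3) yields totals 1, 3, 7 from actions 1, 2, 4.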