您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
156 行
6.9 KiB
156 行
6.9 KiB
from typing import Dict, List, Tuple, Optional
|
|
from mlagents.trainers.settings import (
|
|
EnvironmentParameterSettings,
|
|
ParameterRandomizationSettings,
|
|
)
|
|
from collections import defaultdict
|
|
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
|
|
|
|
from mlagents_envs.logging_util import get_logger
|
|
|
|
# Module-level logger; used below to report when a parameter moves to a new lesson.
logger = get_logger(__name__)
|
|
|
|
|
|
class EnvironmentParameterManager:
    """
    Tracks the curriculum state of every environment parameter of a training
    session.

    The current lesson number of each parameter is stored in (and read back
    from) GlobalTrainingStatus under StatusType.LESSON_NUM; this class decides
    when a lesson number advances and exposes the sampler of the current lesson.
    """

    def __init__(
        self,
        settings: Optional[Dict[str, EnvironmentParameterSettings]] = None,
        run_seed: int = -1,
        restore: bool = False,
    ):
        """
        EnvironmentParameterManager manages all the environment parameters of a
        training session. It determines when parameters should change and gives
        access to the current sampler of each parameter.
        :param settings: A dictionary from environment parameter to
        EnvironmentParameterSettings. None is treated as an empty dictionary.
        :param run_seed: When the seed is not provided for an environment
        parameter, this seed will be used instead.
        :param restore: If true, the EnvironmentParameterManager will use the
        GlobalTrainingStatus to try and reload the lesson status of each
        environment parameter. Otherwise every lesson number is reset to 0.
        """
        if settings is None:
            settings = {}
        self._dict_settings = settings
        for parameter_name in self._dict_settings:
            initial_lesson = GlobalTrainingStatus.get_parameter_state(
                parameter_name, StatusType.LESSON_NUM
            )
            # Start from lesson 0 unless a stored lesson exists AND the caller
            # asked to restore it.
            if initial_lesson is None or not restore:
                GlobalTrainingStatus.set_parameter_state(
                    parameter_name, StatusType.LESSON_NUM, 0
                )
        # Per-parameter smoothing value that is fed to, and updated from, the
        # completion criteria on every call to update_lessons.
        self._smoothed_values: Dict[str, float] = defaultdict(
            float, dict.fromkeys(self._dict_settings, 0.0)
        )
        # Update the seeds of the samplers
        self._set_sampler_seeds(run_seed)

    def _set_sampler_seeds(self, seed: int) -> None:
        """
        Sets the seeds for the samplers (if no seed was already present). A
        sampler whose seed is -1 is considered unseeded and receives
        `seed + offset`, where offset starts at 0 and grows by one after each
        assignment, so that unseeded samplers do not all share the same seed.
        :param seed: The base seed from which the missing seeds are derived.
        """
        offset = 0
        for settings in self._dict_settings.values():
            for lesson in settings.curriculum:
                if lesson.value.seed == -1:
                    lesson.value.seed = seed + offset
                    offset += 1

    def get_minimum_reward_buffer_size(self, behavior_name: str) -> int:
        """
        Calculates the minimum size of the reward buffer a behavior must use.
        This is the largest 'min_lesson_length' among the completion criteria
        (of any lesson of any parameter) that target the behavior, and never
        less than 1.
        :param behavior_name: The name of the behavior the minimum reward buffer
        size corresponds to.
        :returns: The minimum reward buffer size for that behavior.
        """
        result = 1
        for settings in self._dict_settings.values():
            for lesson in settings.curriculum:
                criteria = lesson.completion_criteria
                if criteria is not None and criteria.behavior == behavior_name:
                    result = max(result, criteria.min_lesson_length)
        return result

    def get_current_samplers(self) -> Dict[str, ParameterRandomizationSettings]:
        """
        Creates a dictionary from environment parameter name to their
        corresponding ParameterRandomizationSettings. If curriculum is used, the
        ParameterRandomizationSettings corresponds to the sampler of the current
        lesson.
        """
        samplers: Dict[str, ParameterRandomizationSettings] = {}
        for param_name, settings in self._dict_settings.items():
            lesson_num = GlobalTrainingStatus.get_parameter_state(
                param_name, StatusType.LESSON_NUM
            )
            samplers[param_name] = settings.curriculum[lesson_num].value
        return samplers

    def get_current_lesson_number(self) -> Dict[str, int]:
        """
        Creates a dictionary from environment parameter to the current lesson
        number. If not using curriculum, this number is always 0 for that
        environment parameter.
        """
        result: Dict[str, int] = {}
        for parameter_name in self._dict_settings:
            result[parameter_name] = GlobalTrainingStatus.get_parameter_state(
                parameter_name, StatusType.LESSON_NUM
            )
        return result

    def update_lessons(
        self,
        trainer_steps: Dict[str, int],
        trainer_max_steps: Dict[str, int],
        trainer_reward_buffer: Dict[str, List[float]],
    ) -> Tuple[bool, bool]:
        """
        Given progress metrics, calculates if at least one environment parameter
        is in a new lesson and if at least one environment parameter requires
        the env to reset.
        A parameter is only considered when its current lesson has completion
        criteria, a next lesson exists, and the behavior named by the criteria
        is present in trainer_steps. The other two dictionaries are indexed with
        that same behavior name and must therefore contain it as well.
        :param trainer_steps: A dictionary from behavior_name to the number of
        training steps this behavior's trainer has performed.
        :param trainer_max_steps: A dictionary from behavior_name to the maximum
        number of training steps of this behavior's trainer. It is the
        denominator of the progress ratio handed to the completion criteria.
        :param trainer_reward_buffer: A dictionary from behavior_name to the
        list of the most recent episode returns for this behavior's trainer.
        :returns: A tuple of two booleans : (True if any lesson has changed,
        True if environment needs to reset)
        """
        must_reset = False
        updated = False
        for param_name, settings in self._dict_settings.items():
            lesson_num = GlobalTrainingStatus.get_parameter_state(
                param_name, StatusType.LESSON_NUM
            )
            next_lesson_num = lesson_num + 1
            lesson = settings.curriculum[lesson_num]
            criteria = lesson.completion_criteria
            # Nothing to do on the last lesson or when the lesson never completes.
            if criteria is None or len(settings.curriculum) <= next_lesson_num:
                continue
            behavior_to_consider = criteria.behavior
            if behavior_to_consider not in trainer_steps:
                continue
            # Fraction of the training budget consumed so far by this behavior.
            progress = float(trainer_steps[behavior_to_consider]) / float(
                trainer_max_steps[behavior_to_consider]
            )
            must_increment, new_smoothing = criteria.need_increment(
                progress,
                trainer_reward_buffer[behavior_to_consider],
                self._smoothed_values[param_name],
            )
            # The smoothing value is stored even when the lesson does not change.
            self._smoothed_values[param_name] = new_smoothing
            if must_increment:
                GlobalTrainingStatus.set_parameter_state(
                    param_name, StatusType.LESSON_NUM, next_lesson_num
                )
                new_lesson_name = settings.curriculum[next_lesson_num].name
                logger.info(
                    f"Parameter '{param_name}' has changed. Now in lesson '{new_lesson_name}'"
                )
                updated = True
                # The reset flag comes from the criteria of the lesson being left.
                if criteria.require_reset:
                    must_reset = True
        return updated, must_reset
|