浏览代码

add to_string for samplers (#4484)

Co-authored-by: Chris Elion <chris.elion@unity3d.com>
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
efa2a704
共有 3 个文件被更改,包括 54 次插入4 次删除
  1. 12
      ml-agents/mlagents/trainers/environment_parameter_manager.py
  2. 30
      ml-agents/mlagents/trainers/settings.py
  3. 16
      ml-agents/mlagents/trainers/tests/test_settings.py

12
ml-agents/mlagents/trainers/environment_parameter_manager.py


lesson_num = GlobalTrainingStatus.get_parameter_state(
param_name, StatusType.LESSON_NUM
)
next_lesson_num = lesson_num + 1
and len(settings.curriculum) > lesson_num + 1
and len(settings.curriculum) > next_lesson_num
):
behavior_to_consider = lesson.completion_criteria.behavior
if behavior_to_consider in trainer_steps:

self._smoothed_values[param_name] = new_smoothing
if must_increment:
GlobalTrainingStatus.set_parameter_state(
param_name, StatusType.LESSON_NUM, lesson_num + 1
param_name, StatusType.LESSON_NUM, next_lesson_num
new_lesson_name = settings.curriculum[lesson_num + 1].name
new_lesson_name = settings.curriculum[next_lesson_num].name
new_lesson_value = settings.curriculum[next_lesson_num].value
f"Parameter '{param_name}' has changed. Now in lesson '{new_lesson_name}'"
f"Parameter '{param_name}' has been updated to {new_lesson_value}."
+ f" Now in lesson '{new_lesson_name}'"
)
updated = True
if lesson.completion_criteria.require_reset:

30
ml-agents/mlagents/trainers/settings.py


class ParameterRandomizationSettings(abc.ABC):
seed: int = parser.get_default("seed")
def __str__(self) -> str:
    """
    Base-class string conversion for sampler settings.

    The base class has no printable form of its own: calling this always
    raises TrainerConfigError naming the concrete class, so each sampler
    type is expected to supply its own ``__str__``.
    """
    sampler_type = self.__class__
    raise TrainerConfigError(f"__str__ not implemented for type {sampler_type}.")
@staticmethod
def structure(
d: Union[Mapping, float], t: type

class ConstantSettings(ParameterRandomizationSettings):
value: float = 0.0
def __str__(self) -> str:
    """
    Return a one-line, human-readable description of this constant
    setting, in the form ``Float: value=<value>``.
    """
    current = self.value
    return f"Float: value={current}"
def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
"""
Helper method to send sampler settings over EnvironmentParametersChannel

min_value: float = attr.ib()
max_value: float = 1.0
def __str__(self) -> str:
    """
    Return a one-line, human-readable description of this sampler,
    listing its lower and upper bounds.
    """
    low = self.min_value
    high = self.max_value
    return f"Uniform sampler: min={low}, max={high}"
# attrs default hook for `min_value`: the field is declared with a bare
# attr.ib(), so its default (0.0) is supplied by this decorated method.
@min_value.default
def _min_value_default(self):
return 0.0

mean: float = 1.0
st_dev: float = 1.0
def __str__(self) -> str:
    """
    Return a one-line, human-readable description of this sampler,
    listing its mean and standard deviation.
    """
    center = self.mean
    spread = self.st_dev
    return f"Gaussian sampler: mean={center}, stddev={spread}"
def apply(self, key: str, env_channel: EnvironmentParametersChannel) -> None:
"""
Helper method to send sampler settings over EnvironmentParametersChannel

@attr.s(auto_attribs=True)
class MultiRangeUniformSettings(ParameterRandomizationSettings):
intervals: List[Tuple[float, float]] = attr.ib()
def __str__(self) -> str:
    """
    Return a one-line, human-readable description of this sampler,
    listing its configured intervals.
    """
    ranges = self.intervals
    return f"MultiRangeUniform sampler: intervals={ranges}"
@intervals.default
def _intervals_default(self):

16
ml-agents/mlagents/trainers/tests/test_settings.py


assert isinstance(
env_param_settings["length"].curriculum[0].value, MultiRangeUniformSettings
)
# Check __str__ is correct
assert (
str(env_param_settings["mass"].curriculum[0].value)
== "Uniform sampler: min=1.0, max=2.0"
)
assert (
str(env_param_settings["scale"].curriculum[0].value)
== "Gaussian sampler: mean=1.0, stddev=2.0"
)
assert (
str(env_param_settings["length"].curriculum[0].value)
== "MultiRangeUniform sampler: intervals=[(1.0, 2.0), (3.0, 4.0)]"
)
assert str(env_param_settings["gravity"].curriculum[0].value) == "Float: value=1"
assert isinstance(
env_param_settings["wall_height"].curriculum[0].value, ConstantSettings
)

正在加载...
取消
保存