浏览代码

Fix lesson incrementing (#4279)

* Fix lesson incrementing

* Add warning and test

* Add test for lesson pasing

Co-authored-by: Chris Elion <chris.elion@unity3d.com>
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
d1bf56e9
共有 4 个文件被更改,包括 82 次插入4 次删除
  1. 2
      ml-agents/mlagents/trainers/environment_parameter_manager.py
  2. 8
      ml-agents/mlagents/trainers/exception.py
  3. 12
      ml-agents/mlagents/trainers/settings.py
  4. 64
      ml-agents/mlagents/trainers/tests/test_env_param_manager.py

2
ml-agents/mlagents/trainers/environment_parameter_manager.py


lesson = settings.curriculum[lesson_num]
if (
lesson.completion_criteria is not None
and len(settings.curriculum) > lesson_num
and len(settings.curriculum) > lesson_num + 1
):
behavior_to_consider = lesson.completion_criteria.behavior
if behavior_to_consider in trainer_steps:

8
ml-agents/mlagents/trainers/exception.py


pass
class TrainerConfigWarning(Warning):
"""
Any warning related to the configuration of trainers in the ML-Agents Toolkit.
"""
pass
class CurriculumError(TrainerError):
"""
Any error related to training with a curriculum.

12
ml-agents/mlagents/trainers/settings.py


import warnings
import attr
import cattr
from typing import Dict, Optional, List, Any, DefaultDict, Mapping, Tuple, Union

from mlagents.trainers.cli_utils import StoreConfigFile, DetectDefault, parser
from mlagents.trainers.cli_utils import load_config
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.exception import TrainerConfigError, TrainerConfigWarning
from mlagents_envs import logging_util
from mlagents_envs.side_channel.environment_parameters_channel import (

def _check_lesson_chain(lessons, parameter_name):
"""
Ensures that when using curriculum, all non-terminal lessons have a valid
CompletionCriteria
CompletionCriteria, and that the terminal lesson does not contain a CompletionCriteria.
"""
num_lessons = len(lessons)
for index, lesson in enumerate(lessons):

)
if index == num_lessons - 1 and lesson.completion_criteria is not None:
warnings.warn(
f"Your final lesson definition contains completion_criteria for {parameter_name}."
f"It will be ignored.",
TrainerConfigWarning,
)
@staticmethod

64
ml-agents/mlagents/trainers/tests/test_env_param_manager.py


import yaml
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.exception import TrainerConfigError, TrainerConfigWarning
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.settings import (
RunOptions,

"""
test_bad_curriculum_all_competion_criteria_config_yaml = """
environment_parameters:
param_1:
curriculum:
- name: Lesson1
completion_criteria:
measure: reward
behavior: fake_behavior
threshold: 30
min_lesson_length: 100
require_reset: true
value: 1
- name: Lesson2
completion_criteria:
measure: reward
behavior: fake_behavior
threshold: 30
min_lesson_length: 100
require_reset: true
value: 2
- name: Lesson3
completion_criteria:
measure: reward
behavior: fake_behavior
threshold: 30
min_lesson_length: 100
require_reset: true
value:
sampler_type: uniform
sampler_parameters:
min_value: 1
max_value: 3
"""
def test_curriculum_raises_all_completion_criteria_conversion():
with pytest.warns(TrainerConfigWarning):
run_options = RunOptions.from_dict(
yaml.safe_load(test_bad_curriculum_all_competion_criteria_config_yaml)
)
param_manager = EnvironmentParameterManager(
run_options.environment_parameters, 1337, False
)
assert param_manager.update_lessons(
trainer_steps={"fake_behavior": 500},
trainer_max_steps={"fake_behavior": 1000},
trainer_reward_buffer={"fake_behavior": [1000] * 101},
) == (True, True)
assert param_manager.update_lessons(
trainer_steps={"fake_behavior": 500},
trainer_max_steps={"fake_behavior": 1000},
trainer_reward_buffer={"fake_behavior": [1000] * 101},
) == (True, True)
assert param_manager.update_lessons(
trainer_steps={"fake_behavior": 500},
trainer_max_steps={"fake_behavior": 1000},
trainer_reward_buffer={"fake_behavior": [1000] * 101},
) == (False, False)
assert param_manager.get_current_lesson_number() == {"param_1": 2}
test_everything_config_yaml = """

正在加载...
取消
保存