Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目，它使游戏和模拟能够作为训练智能代理的环境。
您最多可选择25个主题。主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过35个字符。
 
 
 
 
 

138 行
4.4 KiB

import pytest
from unittest.mock import patch, call
from mlagents.trainers.meta_curriculum import MetaCurriculum
from mlagents.trainers.exception import MetaCurriculumError
class MetaCurriculumTest(MetaCurriculum):
    """Test double that bypasses ``MetaCurriculum.__init__``.

    The constructor below never invokes the parent initializer; it only
    stores the supplied mapping on ``_brains_to_curriculums``, which lets
    tests inject already-built (typically mocked) curriculums.
    """

    def __init__(self, brains_to_curriculums):
        # Intentionally no super().__init__() call: the mapping is stored
        # directly on the instance instead.
        self._brains_to_curriculums = brains_to_curriculums
@pytest.fixture
def default_reset_parameters():
    """Provide the baseline reset-parameter dict (param1..param3 -> 1..3)."""
    reset_parameters = {
        "param1": 1,
        "param2": 2,
        "param3": 3,
    }
    return reset_parameters
@pytest.fixture
def more_reset_parameters():
    """Provide a second reset-parameter dict (param4..param6 -> 4..6)."""
    parameter_names = ("param4", "param5", "param6")
    # Values simply count upward from 4 in the same order as the names.
    return {name: value for value, name in enumerate(parameter_names, start=4)}
@pytest.fixture
def measure_vals():
    """Provide the per-brain measure values handed to ``increment_lessons``."""
    brain_names = ("Brain1", "Brain2")
    measures = (0.2, 0.3)
    return dict(zip(brain_names, measures))
@pytest.fixture
def reward_buff_sizes():
    """Provide the per-brain reward buffer sizes used by the tests."""
    buffer_sizes = {
        "Brain1": 7,
        "Brain2": 8,
    }
    return buffer_sizes
@patch("mlagents.trainers.curriculum.Curriculum.get_config", return_value={})
@patch("mlagents.trainers.curriculum.Curriculum.__init__", return_value=None)
@patch("os.listdir", return_value=["Brain1.json", "Brain2.json"])
def test_init_meta_curriculum_happy_path(
    listdir, mock_curriculum_init, mock_curriculum_get_config, default_reset_parameters
):
    """Build a MetaCurriculum from a folder listing of two curriculum files.

    ``os.listdir`` is patched to report ``Brain1.json`` and ``Brain2.json``;
    ``Curriculum.__init__`` and ``Curriculum.get_config`` are patched out.
    The test then checks that both brains are registered and that
    ``Curriculum.__init__`` was called once per file with the reset
    parameters.
    """
    built = MetaCurriculum("test/", default_reset_parameters)
    registered = built.brains_to_curriculums

    assert len(registered) == 2
    for brain_name in ("Brain1", "Brain2"):
        assert brain_name in registered

    # One Curriculum construction is expected per listed JSON file.
    expected_init_calls = [
        call(curriculum_path, default_reset_parameters)
        for curriculum_path in ("test/Brain1.json", "test/Brain2.json")
    ]
    mock_curriculum_init.assert_has_calls(expected_init_calls)
@patch("os.listdir", side_effect=NotADirectoryError())
def test_init_meta_curriculum_bad_curriculum_folder_raises_error(
    listdir, default_reset_parameters
):
    """Check that an unreadable curriculum folder raises MetaCurriculumError.

    ``os.listdir`` is patched to raise ``NotADirectoryError``; constructing a
    ``MetaCurriculum`` is then expected to raise ``MetaCurriculumError``.

    ``default_reset_parameters`` is requested as a fixture argument so the
    constructor receives the parameter dict itself rather than the
    module-level fixture function object of the same name.
    """
    with pytest.raises(MetaCurriculumError):
        MetaCurriculum("test/", default_reset_parameters)
@patch("mlagents.trainers.curriculum.Curriculum")
@patch("mlagents.trainers.curriculum.Curriculum")
def test_set_lesson_nums(curriculum_a, curriculum_b):
    """Assigning ``lesson_nums`` must set ``lesson_num`` on each curriculum."""
    brain_curriculums = {"Brain1": curriculum_a, "Brain2": curriculum_b}
    harness = MetaCurriculumTest(brain_curriculums)

    harness.lesson_nums = {"Brain1": 1, "Brain2": 3}

    # Each mocked curriculum should now carry the lesson number for its brain.
    for mocked, expected_lesson in ((curriculum_a, 1), (curriculum_b, 3)):
        assert mocked.lesson_num == expected_lesson
@patch("mlagents.trainers.curriculum.Curriculum")
@patch("mlagents.trainers.curriculum.Curriculum")
def test_increment_lessons(curriculum_a, curriculum_b, measure_vals):
    """``increment_lessons`` must pass each brain's measure to its curriculum."""
    brain_curriculums = {"Brain1": curriculum_a, "Brain2": curriculum_b}
    harness = MetaCurriculumTest(brain_curriculums)

    harness.increment_lessons(measure_vals)

    # Expected values mirror the ``measure_vals`` fixture entries per brain.
    for mocked, expected_measure in ((curriculum_a, 0.2), (curriculum_b, 0.3)):
        mocked.increment_lesson.assert_called_with(expected_measure)
@patch("mlagents.trainers.curriculum.Curriculum")
@patch("mlagents.trainers.curriculum.Curriculum")
def test_increment_lessons_with_reward_buff_sizes(
    curriculum_a, curriculum_b, measure_vals, reward_buff_sizes
):
    """Only curriculum_a is expected to be incremented when buffer sizes are given.

    NOTE(review): presumably curriculum_b is skipped because its reward
    buffer size is below its ``min_lesson_length`` of 10 -- confirm against
    ``MetaCurriculum.increment_lessons``.
    """
    for mocked, min_length in ((curriculum_a, 5), (curriculum_b, 10)):
        mocked.min_lesson_length = min_length

    brain_curriculums = {"Brain1": curriculum_a, "Brain2": curriculum_b}
    harness = MetaCurriculumTest(brain_curriculums)

    harness.increment_lessons(measure_vals, reward_buff_sizes=reward_buff_sizes)

    curriculum_a.increment_lesson.assert_called_with(0.2)
    curriculum_b.increment_lesson.assert_not_called()
@patch("mlagents.trainers.curriculum.Curriculum")
@patch("mlagents.trainers.curriculum.Curriculum")
def test_set_all_curriculums_to_lesson_num(curriculum_a, curriculum_b):
    """``set_all_curriculums_to_lesson_num`` must set every curriculum's lesson."""
    brain_curriculums = {"Brain1": curriculum_a, "Brain2": curriculum_b}
    harness = MetaCurriculumTest(brain_curriculums)

    harness.set_all_curriculums_to_lesson_num(2)

    # Both mocked curriculums should end up on the same lesson number.
    for mocked in (curriculum_a, curriculum_b):
        assert mocked.lesson_num == 2
@patch("mlagents.trainers.curriculum.Curriculum")
@patch("mlagents.trainers.curriculum.Curriculum")
def test_get_config(
    curriculum_a, curriculum_b, default_reset_parameters, more_reset_parameters
):
    """``get_config`` must equal the combined configs of all curriculums.

    First both mocked curriculums report the same config, so the result
    should equal that config. Then curriculum_b reports a different config
    and the result should equal the two dicts combined.
    """
    for mocked in (curriculum_a, curriculum_b):
        mocked.get_config.return_value = default_reset_parameters

    brain_curriculums = {"Brain1": curriculum_a, "Brain2": curriculum_b}
    harness = MetaCurriculumTest(brain_curriculums)

    assert harness.get_config() == default_reset_parameters

    # Switch curriculum_b to a different config and expect both sets of keys.
    curriculum_b.get_config.return_value = more_reset_parameters
    combined_parameters = dict(default_reset_parameters, **more_reset_parameters)

    assert harness.get_config() == combined_parameters