您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
109 行
3.8 KiB
109 行
3.8 KiB
import pytest
|
|
from unittest.mock import patch, call, Mock
|
|
|
|
from mlagents.trainers.meta_curriculum import MetaCurriculum
|
|
from mlagents.trainers.exception import MetaCurriculumError
|
|
|
|
|
|
class MetaCurriculumTest(MetaCurriculum):
    """Test double for MetaCurriculum that skips the parent constructor.

    MetaCurriculum's own ``__init__`` is deliberately not invoked; the value
    handed to this constructor is stored directly on
    ``_brains_to_curriculums`` so that tests can inject the curriculum
    objects of their choosing.
    """

    def __init__(self, brains_to_curriculums):
        # Intentionally no super().__init__() call: only the attribute
        # assignment below happens when an instance is created.
        self._brains_to_curriculums = brains_to_curriculums
|
|
|
|
|
|
@pytest.fixture
def default_reset_parameters():
    """Provide a reset-parameter dict mapping param1..param3 to 1..3."""
    reset_parameters = {
        'param1': 1,
        'param2': 2,
        'param3': 3,
    }
    return reset_parameters
|
|
|
|
|
|
@pytest.fixture
def more_reset_parameters():
    """Provide a reset-parameter dict mapping param4..param6 to 4..6."""
    reset_parameters = {
        'param4': 4,
        'param5': 5,
        'param6': 6,
    }
    return reset_parameters
|
|
|
|
|
|
@pytest.fixture
def progresses():
    """Provide one progress value per brain name (Brain1 and Brain2)."""
    progress_by_brain = {
        'Brain1': 0.2,
        'Brain2': 0.3,
    }
    return progress_by_brain
|
|
|
|
|
|
@patch('mlagents.trainers.Curriculum.get_config', return_value={})
@patch('mlagents.trainers.Curriculum.__init__', return_value=None)
@patch('os.listdir', return_value=['Brain1.json', 'Brain2.json'])
def test_init_meta_curriculum_happy_path(listdir, mock_curriculum_init,
                                         mock_curriculum_get_config,
                                         default_reset_parameters):
    """Check MetaCurriculum construction from a folder with two JSON files.

    ``os.listdir`` is patched to report ``Brain1.json`` and ``Brain2.json``,
    and ``Curriculum.__init__`` / ``Curriculum.get_config`` are patched out.
    The test expects one entry per brain name and one ``Curriculum.__init__``
    call per file, each receiving the default reset parameters.
    """
    meta_curriculum = MetaCurriculum('test/', default_reset_parameters)

    brains_to_curriculums = meta_curriculum.brains_to_curriculums
    assert len(brains_to_curriculums) == 2
    for brain_name in ('Brain1', 'Brain2'):
        assert brain_name in brains_to_curriculums

    # One constructor call is expected per curriculum file, in this order.
    expected_calls = [
        call(curriculum_path, default_reset_parameters)
        for curriculum_path in ('test/Brain1.json', 'test/Brain2.json')
    ]
    mock_curriculum_init.assert_has_calls(expected_calls)
|
|
|
|
|
|
@patch('os.listdir', side_effect=NotADirectoryError())
def test_init_meta_curriculum_bad_curriculum_folder_raises_error(
        listdir, default_reset_parameters):
    """Check that a bad curriculum folder surfaces as MetaCurriculumError.

    ``os.listdir`` is patched to raise ``NotADirectoryError``; constructing a
    MetaCurriculum is then expected to raise ``MetaCurriculumError``.

    ``default_reset_parameters`` is requested as a fixture argument so the
    constructor receives the fixture's dict rather than the module-level
    fixture function object.
    """
    with pytest.raises(MetaCurriculumError):
        MetaCurriculum('test/', default_reset_parameters)
|
|
|
|
|
|
@patch('mlagents.trainers.Curriculum')
@patch('mlagents.trainers.Curriculum')
def test_set_lesson_nums(curriculum_a, curriculum_b):
    """Check that assigning ``lesson_nums`` updates each curriculum.

    The two stacked patches supply two mock curriculums, registered under
    'Brain1' and 'Brain2' respectively.
    """
    brains_to_curriculums = {
        'Brain1': curriculum_a,
        'Brain2': curriculum_b,
    }
    meta_curriculum = MetaCurriculumTest(brains_to_curriculums)

    meta_curriculum.lesson_nums = {'Brain1': 1, 'Brain2': 3}

    # Each mock should now carry the lesson number assigned to its brain.
    for curriculum, expected_lesson in ((curriculum_a, 1), (curriculum_b, 3)):
        assert curriculum.lesson_num == expected_lesson
|
|
|
|
|
|
|
|
@patch('mlagents.trainers.Curriculum')
@patch('mlagents.trainers.Curriculum')
def test_increment_lessons(curriculum_a, curriculum_b, progresses):
    """Check that ``increment_lessons`` forwards each brain's progress.

    The mock registered as 'Brain1' must be asked to increment with 0.2 and
    the one registered as 'Brain2' with 0.3, matching the ``progresses``
    fixture.
    """
    brains_to_curriculums = {
        'Brain1': curriculum_a,
        'Brain2': curriculum_b,
    }
    meta_curriculum = MetaCurriculumTest(brains_to_curriculums)

    meta_curriculum.increment_lessons(progresses)

    for curriculum, expected_progress in ((curriculum_a, 0.2),
                                          (curriculum_b, 0.3)):
        curriculum.increment_lesson.assert_called_with(expected_progress)
|
|
|
|
|
|
@patch('mlagents.trainers.Curriculum')
@patch('mlagents.trainers.Curriculum')
def test_set_all_curriculums_to_lesson_num(curriculum_a, curriculum_b):
    """Check that one call sets the same lesson number on every curriculum."""
    brains_to_curriculums = {
        'Brain1': curriculum_a,
        'Brain2': curriculum_b,
    }
    meta_curriculum = MetaCurriculumTest(brains_to_curriculums)

    meta_curriculum.set_all_curriculums_to_lesson_num(2)

    # Both mocks are expected to end up on lesson 2.
    for curriculum in (curriculum_a, curriculum_b):
        assert curriculum.lesson_num == 2
|
|
|
|
|
|
@patch('mlagents.trainers.Curriculum')
@patch('mlagents.trainers.Curriculum')
def test_get_config(curriculum_a, curriculum_b, default_reset_parameters,
                    more_reset_parameters):
    """Check the config reported by a meta curriculum with two curriculums.

    First both mocks return the default reset parameters and the meta
    curriculum is expected to report exactly those. Then the second mock is
    switched to the additional parameters and the expected result is the
    default parameters updated with the additional ones.
    """
    for curriculum in (curriculum_a, curriculum_b):
        curriculum.get_config.return_value = default_reset_parameters
    brains_to_curriculums = {
        'Brain1': curriculum_a,
        'Brain2': curriculum_b,
    }
    meta_curriculum = MetaCurriculumTest(brains_to_curriculums)

    assert meta_curriculum.get_config() == default_reset_parameters

    curriculum_b.get_config.return_value = more_reset_parameters

    # Build the expectation on a copy so the fixture dict is left unmodified.
    expected_config = default_reset_parameters.copy()
    expected_config.update(more_reset_parameters)

    assert meta_curriculum.get_config() == expected_config
|