import pytest
from unittest.mock import patch, call, Mock

from mlagents.trainers.meta_curriculum import MetaCurriculum
from mlagents.trainers.exception import MetaCurriculumError


class MetaCurriculumTest(MetaCurriculum):
    """Test double for MetaCurriculum that bypasses MetaCurriculum.__init__.

    The real constructor reads a curriculum folder; this subclass instead
    receives a ready-made mapping of brain name -> curriculum so individual
    MetaCurriculum methods can be exercised in isolation.
    """

    def __init__(self, brains_to_curriculums):
        # Only _brains_to_curriculums is populated; the parent __init__ is
        # deliberately not called.
        self._brains_to_curriculums = brains_to_curriculums


@pytest.fixture
def default_reset_parameters():
    """Baseline reset parameters shared by several tests."""
    return {'param1': 1, 'param2': 2, 'param3': 3}


@pytest.fixture
def more_reset_parameters():
    """Reset parameters whose keys do not overlap with the defaults."""
    return {'param4': 4, 'param5': 5, 'param6': 6}


@pytest.fixture
def progresses():
    """Per-brain progress values passed to MetaCurriculum.increment_lessons."""
    return {'Brain1': 0.2, 'Brain2': 0.3}


@patch('mlagents.trainers.Curriculum.get_config', return_value={})
@patch('mlagents.trainers.Curriculum.__init__', return_value=None)
@patch('os.listdir', return_value=['Brain1.json', 'Brain2.json'])
def test_init_meta_curriculum_happy_path(listdir, mock_curriculum_init,
                                         mock_curriculum_get_config,
                                         default_reset_parameters):
    """Each file in the folder yields one curriculum keyed by brain name."""
    # @patch decorators apply bottom-up, so the mocks arrive in the order
    # os.listdir, Curriculum.__init__, Curriculum.get_config.
    meta_curriculum = MetaCurriculum('test/', default_reset_parameters)

    assert len(meta_curriculum.brains_to_curriculums) == 2
    assert 'Brain1' in meta_curriculum.brains_to_curriculums
    assert 'Brain2' in meta_curriculum.brains_to_curriculums

    # A Curriculum is built for every listed file, with the folder-prefixed
    # path and the shared reset parameters.
    calls = [call('test/Brain1.json', default_reset_parameters),
             call('test/Brain2.json', default_reset_parameters)]
    mock_curriculum_init.assert_has_calls(calls)


@patch('os.listdir', side_effect=NotADirectoryError())
def test_init_meta_curriculum_bad_curriculum_folder_raises_error(
        listdir, default_reset_parameters):
    """A curriculum path that is not a directory raises MetaCurriculumError."""
    # The fixture is requested as an argument so the actual dict is passed;
    # a bare reference to the name would hand over the fixture function.
    with pytest.raises(MetaCurriculumError):
        MetaCurriculum('test/', default_reset_parameters)


def test_set_lesson_nums():
    """Assigning lesson_nums sets lesson_num on each brain's curriculum."""
    # Plain mocks stand in for curricula; nothing global needs patching.
    curriculum_a = Mock()
    curriculum_b = Mock()
    meta_curriculum = MetaCurriculumTest({'Brain1': curriculum_a,
                                          'Brain2': curriculum_b})

    meta_curriculum.lesson_nums = {'Brain1': 1, 'Brain2': 3}

    assert curriculum_a.lesson_num == 1
    assert curriculum_b.lesson_num == 3
def test_increment_lessons(progresses):
    """increment_lessons forwards each brain's progress to its curriculum."""
    # Plain mocks stand in for curricula; nothing global needs patching.
    curriculum_a = Mock()
    curriculum_b = Mock()
    meta_curriculum = MetaCurriculumTest({'Brain1': curriculum_a,
                                          'Brain2': curriculum_b})

    meta_curriculum.increment_lessons(progresses)

    curriculum_a.increment_lesson.assert_called_with(0.2)
    curriculum_b.increment_lesson.assert_called_with(0.3)


def test_set_all_curriculums_to_lesson_num():
    """set_all_curriculums_to_lesson_num applies one lesson to every brain."""
    curriculum_a = Mock()
    curriculum_b = Mock()
    meta_curriculum = MetaCurriculumTest({'Brain1': curriculum_a,
                                          'Brain2': curriculum_b})

    meta_curriculum.set_all_curriculums_to_lesson_num(2)

    assert curriculum_a.lesson_num == 2
    assert curriculum_b.lesson_num == 2


def test_get_config(default_reset_parameters, more_reset_parameters):
    """get_config combines the reset parameters of every curriculum."""
    curriculum_a = Mock()
    curriculum_b = Mock()
    curriculum_a.get_config.return_value = default_reset_parameters
    curriculum_b.get_config.return_value = default_reset_parameters
    meta_curriculum = MetaCurriculumTest({'Brain1': curriculum_a,
                                          'Brain2': curriculum_b})

    # Identical per-brain configs yield the same parameters once.
    assert meta_curriculum.get_config() == default_reset_parameters

    # Disjoint per-brain configs are combined into one dict.
    curriculum_b.get_config.return_value = more_reset_parameters
    new_reset_parameters = dict(default_reset_parameters)
    new_reset_parameters.update(more_reset_parameters)
    assert meta_curriculum.get_config() == new_reset_parameters