浏览代码

Finished testing School. Added documentation.

/develop-generalizationTraining-TrainerController
Deric Pang 6 年前
当前提交
10ab5965
共有 2 个文件被更改,包括 113 次插入16 次删除
  1. 79
      python/tests/test_school.py
  2. 50
      python/unitytrainers/school.py

79
python/tests/test_school.py


import pytest
from unittest.mock import patch, call
from unittest.mock import patch, call, Mock
from unitytrainers import School
from unitytrainers.school import School
class SchoolTest(School):
"""This class allows us to test School objects without calling School's
__init__ function.
"""
def __init__(self, brains_to_curriculums):
self._brains_to_curriculums = brains_to_curriculums
return {"param1": 1, "param2": 1, "param3": 1}
return {'param1' : 1, 'param2' : 2, 'param3' : 3}
@pytest.fixture
def more_reset_parameters():
return {'param4' : 4, 'param5' : 5, 'param6' : 6}
@pytest.fixture
def progresses():
return {'TestBrain1' : 0.2, 'TestBrain2' : 0.3}
def test_init_school_happy_path(listdir, mock_curriculum, default_reset_parameters):
def test_init_school_happy_path(listdir, mock_curriculum_init, default_reset_parameters):
school = School('test-school/', default_reset_parameters)
assert len(school.brains_to_curriculums) == 2

calls = [call('test-school/TestBrain1.json', default_reset_parameters), call('test-school/TestBrain2.json', default_reset_parameters)]
mock_curriculum.assert_has_calls(calls)
mock_curriculum_init.assert_has_calls(calls)
@patch('unitytrainers.Curriculum')
@patch('unitytrainers.Curriculum')
def test_set_lesson_nums(test_brain_1_curriculum, test_brain_2_curriculum):
school = SchoolTest({'TestBrain1' : test_brain_1_curriculum, 'TestBrain2' : test_brain_2_curriculum})
school.lesson_nums = {'TestBrain1' : 1, 'TestBrain2' : 3}
assert test_brain_1_curriculum.lesson_num == 1
assert test_brain_2_curriculum.lesson_num == 3
@patch('unitytrainers.Curriculum')
@patch('unitytrainers.Curriculum')
def test_increment_lessons(test_brain_1_curriculum, test_brain_2_curriculum, progresses):
school = SchoolTest({'TestBrain1' : test_brain_1_curriculum, 'TestBrain2' : test_brain_2_curriculum})
school.increment_lessons(progresses)
test_brain_1_curriculum.increment_lesson.assert_called_with(0.2)
test_brain_2_curriculum.increment_lesson.assert_called_with(0.3)
@patch('unitytrainers.Curriculum')
@patch('unitytrainers.Curriculum')
def test_set_all_curriculums_to_lesson_num(test_brain_1_curriculum, test_brain_2_curriculum):
school = SchoolTest({'TestBrain1' : test_brain_1_curriculum, 'TestBrain2' : test_brain_2_curriculum})
school.set_all_curriculums_to_lesson_num(2)
assert test_brain_1_curriculum.lesson_num == 2
assert test_brain_2_curriculum.lesson_num == 2
@patch('unitytrainers.Curriculum')
@patch('unitytrainers.Curriculum')
def test_get_config(test_brain_1_curriculum, test_brain_2_curriculum, default_reset_parameters, more_reset_parameters):
test_brain_1_curriculum.get_config.return_value = default_reset_parameters
test_brain_2_curriculum.get_config.return_value = default_reset_parameters
school = SchoolTest({'TestBrain1' : test_brain_1_curriculum, 'TestBrain2' : test_brain_2_curriculum})
assert school.get_config() == default_reset_parameters
test_brain_2_curriculum.get_config.return_value = more_reset_parameters
new_reset_parameters = dict(default_reset_parameters)
new_reset_parameters.update(more_reset_parameters)
assert school.get_config() == new_reset_parameters

50
python/unitytrainers/school.py


"""
A School holds many curriculums. The School tracks which brains are following which curriculums.
"""
"""Contains the School class."""
class School:
class School(object):
"""A School holds curriculums. Each curriculum is associated to a particular
brain in the environment.
"""
"""
Initializes a School object.
"""Initializes a School object.
Args:
curriculum_folder (str): The relative or absolute path of the
folder which holds the curriculums for this environment.
The folder should contain JSON files whose names are the
brains that the curriculums belong to.
default_reset_parameters (dict): The default reset parameters
of the environment.
"""
if curriculum_folder is None:
self._brains_to_curriculums = None

brain_name = curriculum_filename.split('.')[0]
curriculum_filepath = os.path.join(curriculum_folder, curriculum_filename)
self._brains_to_curriculums[brain_name] = Curriculum(curriculum_filepath, default_reset_parameters)
curriculum_filepath = \
os.path.join(curriculum_folder, curriculum_filename)
self._brains_to_curriculums[brain_name] = \
Curriculum(curriculum_filepath, default_reset_parameters)
"""A dict from brain_name to the brain's curriculum."""
"""A dict from brain name to the brain's curriculum's lesson number."""
lesson_nums = {}
for brain_name, curriculum in self.brains_to_curriculums:
lesson_nums[brain_name] = curriculum.lesson_num

self.brains_to_curriculums[brain_name].lesson_num = lesson
def increment_lessons(self, progresses):
"""Increments all the lessons of all the curriculums in this School.
Args:
progresses (dict): A dict of brain name to progress.
"""
"""Sets all the curriculums in this school to a specified lesson number.
Args:
lesson_num (int): The lesson number which all the curriculums will
be set to.
"""
"""Get the combined configuration of all curriculums in this School.
Returns:
A dict from parameter to value.
"""
parameters = curriculum.data["parameters"]
for key in parameters:
config[key] = parameters[key][curriculum.lesson_num]
curr_config = curriculum.get_config()
config.update(curr_config)
return config
正在加载...
取消
保存