浏览代码

Curriculums now hold the brain name.

/develop-generalizationTraining-TrainerController
Deric Pang 7 年前
当前提交
361d56b9
共有 2 个文件被更改,包括 36 次插入33 次删除
  1. 1
      python/tests/test_curriculum.py
  2. 68
      python/unitytrainers/curriculum.py

1
python/tests/test_curriculum.py


def test_init_curriculum_happy_path(mock_file, location, default_reset_parameters):
curriculum = Curriculum(location, default_reset_parameters)
assert curriculum._brain_name == 'TestBrain'
assert curriculum.lesson_num == 0
assert curriculum.measure == 'reward'

68
python/unitytrainers/curriculum.py


import os
import json
from .exception import CurriculumError

self.max_lesson_num = 0
self.measure = None
self._lesson_num = 0
# The name of the brain should be the basename of the file without the
# extension.
self._brain_name = os.path.basename(location).split('.')[0]
if location is None:
self.data = None
else:
try:
with open(location) as data_file:
self.data = json.load(data_file)
except IOError:
raise CurriculumError(
'The file {0} could not be found.'.format(location))
except UnicodeDecodeError:
raise CurriculumError('There was an error decoding {}'.format(location))
self.smoothing_value = 0
for key in ['parameters', 'measure', 'thresholds',
'min_lesson_length', 'signal_smoothing']:
if key not in self.data:
raise CurriculumError("{0} does not contain a "
"{1} field.".format(location, key))
self.smoothing_value = 0
self.measure = self.data['measure']
self.max_lesson_num = len(self.data['thresholds'])
try:
with open(location) as data_file:
self.data = json.load(data_file)
except IOError:
raise CurriculumError(
'The file {0} could not be found.'.format(location))
except UnicodeDecodeError:
raise CurriculumError('There was an error decoding {}'.format(location))
self.smoothing_value = 0
for key in ['parameters', 'measure', 'thresholds',
'min_lesson_length', 'signal_smoothing']:
if key not in self.data:
raise CurriculumError("{0} does not contain a "
"{1} field.".format(location, key))
self.smoothing_value = 0
self.measure = self.data['measure']
self.max_lesson_num = len(self.data['thresholds'])
parameters = self.data['parameters']
for key in parameters:
if key not in default_reset_parameters:
raise CurriculumError(
'The parameter {0} in Curriculum {1} is not present in '
'the Environment'.format(key, location))
if len(parameters[key]) != self.max_lesson_num + 1:
raise CurriculumError(
'The parameter {0} in Curriculum {1} must have {2} values '
'but {3} were found'.format(key, location,
self.max_lesson_num + 1, len(parameters[key])))
parameters = self.data['parameters']
for key in parameters:
if key not in default_reset_parameters:
raise CurriculumError(
'The parameter {0} in Curriculum {1} is not present in '
'the Environment'.format(key, location))
if len(parameters[key]) != self.max_lesson_num + 1:
raise CurriculumError(
'The parameter {0} in Curriculum {1} must have {2} values '
'but {3} were found'.format(key, location,
self.max_lesson_num + 1, len(parameters[key])))
@property
def lesson_num(self):

parameters = self.data['parameters']
for key in parameters:
config[key] = parameters[key][self.lesson_num]
logger.info('\nLesson changed. Now in Lesson {0} : \t{1}'
.format(self.lesson_num,
logger.info('\n{0} lesson changed. Now in Lesson {1} : \t{2}'
.format(self._brain_name,
self.lesson_num,
', '.join([str(x) + ' -> ' + str(config[x]) for x in config])))
def get_config(self, lesson=None):

正在加载...
取消
保存