您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
104 行
4.3 KiB
104 行
4.3 KiB
import json
|
|
|
|
from unityagents.exception import UnityEnvironmentException
|
|
|
|
import logging
|
|
|
|
# Shared package-level logger; lesson changes are reported through it.
logger = logging.getLogger(name="unityagents")
class Curriculum(object):
    """
    Lesson-based curriculum loaded from a JSON file.

    The JSON must contain the fields ``parameters`` (reset-parameter name ->
    list of per-lesson values), ``measure``, ``thresholds`` (one progress
    threshold per lesson transition), ``min_lesson_length`` and
    ``signal_smoothing``. When ``location`` is None the curriculum is
    disabled: ``get_config`` returns an empty dict and ``increment_lesson``
    does nothing.
    """

    def __init__(self, location, default_reset_parameters):
        """
        Initializes a Curriculum object.
        :param location: Path to JSON defining curriculum, or None to disable
            the curriculum.
        :param default_reset_parameters: Set of reset parameters for environment.
            Every parameter named by the curriculum must appear here.
        :raises UnityEnvironmentException: if the file cannot be opened or
            decoded (bad encoding or malformed JSON), lacks a required field,
            names a parameter unknown to the environment, or gives a
            parameter a number of values other than len(thresholds) + 1.
        """
        self.lesson_length = 0
        self.max_lesson_number = 0
        self.measure_type = None
        # Initialised unconditionally so the instance always has the same
        # attributes, whether or not a curriculum file was given.
        self.smoothing_value = 0
        if location is None:
            self.data = None
        else:
            try:
                with open(location) as data_file:
                    self.data = json.load(data_file)
            except IOError as err:
                raise UnityEnvironmentException(
                    "The file {0} could not be found.".format(location)) from err
            except ValueError as err:
                # ValueError covers both UnicodeDecodeError and
                # json.JSONDecodeError, so a malformed file surfaces as the
                # project exception instead of a raw parser error.
                raise UnityEnvironmentException(
                    "There was an error decoding {}".format(location)) from err
            for key in ['parameters', 'measure', 'thresholds',
                        'min_lesson_length', 'signal_smoothing']:
                if key not in self.data:
                    raise UnityEnvironmentException("{0} does not contain a "
                                                    "{1} field.".format(location, key))
            parameters = self.data['parameters']
            self.measure_type = self.data['measure']
            # N thresholds separate N + 1 lessons (indices 0..N).
            self.max_lesson_number = len(self.data['thresholds'])
            for key in parameters:
                if key not in default_reset_parameters:
                    raise UnityEnvironmentException(
                        "The parameter {0} in Curriculum {1} is not present in "
                        "the Environment".format(key, location))
            expected_values = self.max_lesson_number + 1
            for key in parameters:
                if len(parameters[key]) != expected_values:
                    raise UnityEnvironmentException(
                        "The parameter {0} in Curriculum {1} must have {2} values "
                        "but {3} were found".format(key, location,
                                                    expected_values, len(parameters[key])))
        self.set_lesson_number(0)

    @property
    def measure(self):
        """The progress measure named by the curriculum (None when disabled)."""
        return self.measure_type

    @property
    def get_lesson_number(self):
        """The current lesson index (a property despite its getter-style name)."""
        return self.lesson_number

    def set_lesson_number(self, value):
        """
        Moves to a lesson and resets the count of updates spent in it.
        :param value: Requested lesson index; clamped to [0, max_lesson_number].
        """
        self.lesson_length = 0
        self.lesson_number = max(0, min(value, self.max_lesson_number))

    def increment_lesson(self, progress):
        """
        Increments the lesson number depending on the progress given.
        The lesson advances by at most one per call. It advances when two
        conditions hold: progress (smoothed when ``signal_smoothing`` is set)
        exceeds the current lesson's threshold, and more than
        ``min_lesson_length`` calls have been made in the current lesson.
        Does nothing when no curriculum is loaded or progress is None.
        :param progress: Measure of progress (either reward or percentage steps completed).
        """
        if self.data is None or progress is None:
            return
        if self.data["signal_smoothing"]:
            # Blend 75% of the new value with 25% of the previous smoothed value.
            progress = self.smoothing_value * 0.25 + 0.75 * progress
            self.smoothing_value = progress
        self.lesson_length += 1
        if self.lesson_number >= self.max_lesson_number:
            return  # already in the final lesson
        if (progress > self.data['thresholds'][self.lesson_number]
                and self.lesson_length > self.data['min_lesson_length']):
            self.lesson_length = 0
            self.lesson_number += 1
            config = self.get_config()
            logger.info("\nLesson changed. Now in Lesson %s : \t%s",
                        self.lesson_number,
                        ', '.join('{0} -> {1}'.format(key, value)
                                  for key, value in config.items()))

    def get_config(self, lesson=None):
        """
        Returns reset parameters which correspond to the lesson.
        :param lesson: The lesson you want to get the config of. If None, the
            current lesson is used. Out-of-range values are clamped to
            [0, max_lesson_number].
        :return: The configuration of the reset parameters, as a dict mapping
            parameter name to its value for that lesson (empty when no
            curriculum is loaded).
        """
        if self.data is None:
            return {}
        if lesson is None:
            lesson = self.lesson_number
        lesson = max(0, min(lesson, self.max_lesson_number))
        return {key: values[lesson]
                for key, values in self.data["parameters"].items()}