浏览代码

Made the code cleanup changes - mostly nit changes

/develop-generalizationTraining-TrainerController
sankalp04 5 年前
当前提交
db858686
共有 1 个文件被更改,包括 16 次插入14 次删除
  1. 30
      ml-agents/mlagents/trainers/lesson_controller.py

30
ml-agents/mlagents/trainers/lesson_controller.py


import os
import logging
import logging
class LessonController(object):
class LessonController:
def __init__(self, location):
"""
Initializes a Curriculum object.

try:
with open(location) as data_file:
self.data = yaml.load(data_file)
data = yaml.load(data_file)
self.check_keys(location)
self.measure = self.data["measure"]
self.thresholds = self.data["thresholds"]
self.min_lesson_length = self.data["min_lesson_length"]
self.check_keys(data, location)
self.measure = data["measure"]
self.thresholds = data["thresholds"]
self.min_lesson_length = data["min_lesson_length"]
self.signal_smoothing = data["signal_smoothing"]
self.test_lesson_length = (data["test_lesson_length"]
if "test_lesson_length" in data
else 1000)
def check_keys(self, location):
def check_keys(self, data, location):
for key in [
"measure",
"thresholds",

if key not in self.data:
if key not in data:
raise LessonControllerError(
"{0} does not contain a " "{1} field.".format(location, key)
)

:param measure_val: A dict of brain name to measure value.
:return Whether the lesson was incremented.
"""
if not self.data or not measure_val or math.isnan(measure_val):
if (not measure_val) or math.isnan(measure_val):
if self.data["signal_smoothing"]:
if self.signal_smoothing:
if measure_val >= self.data["thresholds"][self.lesson_num]:
if measure_val >= self.thresholds[self.lesson_num]:
self.lesson_num += 1
logger.info(
"Lesson changed. Now in lesson {0}".format(

正在加载...
取消
保存