浏览代码
initial commit of the curriculum with broadcast. Improved the Unity python handshake
/develop-generalizationTraining-TrainerController
initial commit of the curriculum with broadcast. Improved the Unity python handshake
/develop-generalizationTraining-TrainerController
vincentpierre
7 年前
当前提交
ac910514
共有 5 个文件被更改,包括 91 次插入 和 6 次删除
-
1python/unityagents/__init__.py
-
15python/unityagents/environment.py
-
1unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs
-
12python/curriculum.json
-
68python/unityagents/curriculum.py
|
|||
from .environment import * |
|||
from .brain import * |
|||
from .exception import * |
|||
from .curriculum import * |
|
|||
{ |
|||
"measure" : "progress", |
|||
"thresholds" : [0.1, 0.2, 0.5], |
|||
"min_lesson_length" : 3, |
|||
"signal_smoothing" : true, |
|||
"parameters" : |
|||
{ |
|||
"param1" : [100, 50, 10, 1], |
|||
"param2" : [0.1, 0.1, 0.5, 0.01,0], |
|||
"param3" : [0.2, 0.3, 0.7, 0.9] |
|||
} |
|||
} |
|
|||
import json |
|||
import numpy as np |
|||
|
|||
from .exception import UnityEnvironmentException |
|||
|
|||
class Curriculum(object): |
|||
def __init__(self, location, default_reset_parameters): |
|||
self.lesson_number = 0 |
|||
self.lesson_length = 0 |
|||
self.measure_type = None |
|||
if location == None: |
|||
self.data = None |
|||
else: |
|||
try: |
|||
with open(location) as data_file: |
|||
self.data = json.load(data_file) |
|||
except: |
|||
raise UnityEnvironmentException( |
|||
"The file {0} could not be found.".format(location)) |
|||
parameters = self.data["parameters"] |
|||
self.measure_type = self.data["measure"] |
|||
self.max_lesson_number = len(self.data['thresholds']) |
|||
self.smoothing_value = 0 |
|||
for key in parameters: |
|||
if key not in default_reset_parameters: |
|||
raise UnityEnvironmentException( |
|||
"The parameter {0} in Curriculum {1} is not present in " |
|||
"the Environment".format(key, location)) |
|||
for key in parameters: |
|||
if len(parameters[key]) != self.max_lesson_number + 1: |
|||
raise UnityEnvironmentException( |
|||
"The parameter {0} in Curriculum {1} must have {2} values " |
|||
"but {3} were found".format(key, location, |
|||
self.max_lesson_number + 1, len(parameters[key]))) |
|||
|
|||
|
|||
@property |
|||
def measure(self): |
|||
return self.measure_type |
|||
|
|||
def get_lesson_number(self): |
|||
return self.lesson_number |
|||
|
|||
def set_lesson_number(self, value): |
|||
self.lesson_length = 0 |
|||
self.lesson_number = max(0,min(value,self.max_lesson_number)) |
|||
|
|||
def get_lesson(self, progress): |
|||
if (self.data == None ) or (progress == None): |
|||
return {} |
|||
if self.data["signal_smoothing"]: |
|||
progress = self.smoothing_value*0.9 + 0.1*progress |
|||
self.smoothing_value = progress |
|||
self.lesson_length += 1 |
|||
if self.lesson_number < self.max_lesson_number: |
|||
if ((progress > self.data['thresholds'][self.lesson_number]) |
|||
and (self.lesson_length > self.data['min_lesson_length'])): |
|||
self.lesson_length = 0 |
|||
self.lesson_number += 1 |
|||
config = {} |
|||
parameters = self.data["parameters"] |
|||
for key in parameters: |
|||
config[key] = parameters[key][self.lesson_number] |
|||
return config |
|||
|
|||
|
|||
|
|||
|
撰写
预览
正在加载...
取消
保存
Reference in new issue