vincentpierre
7 年前
当前提交
3b00302a
共有 30 个文件被更改,包括 1032 次插入 和 278 次删除
-
23docs/Making-a-new-Unity-Environment.md
-
35python/PPO.ipynb
-
35python/ppo.py
-
9python/ppo/models.py
-
3python/ppo/trainer.py
-
7python/test_unityagents.py
-
1python/unityagents/__init__.py
-
3python/unityagents/brain.py
-
117python/unityagents/environment.py
-
3unity-environment/Assets/ML-Agents/Examples/3DBall/Scripts/Ball3DDecision.cs
-
21unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicDecision.cs
-
6unity-environment/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
-
242unity-environment/Assets/ML-Agents/Examples/Tennis/Tennis.unity
-
8unity-environment/Assets/ML-Agents/Scripts/Academy.cs
-
18unity-environment/Assets/ML-Agents/Scripts/Agent.cs
-
35unity-environment/Assets/ML-Agents/Scripts/Brain.cs
-
7unity-environment/Assets/ML-Agents/Scripts/Communicator.cs
-
32unity-environment/Assets/ML-Agents/Scripts/CoreBrainExternal.cs
-
18unity-environment/Assets/ML-Agents/Scripts/CoreBrainHeuristic.cs
-
51unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
-
19unity-environment/Assets/ML-Agents/Scripts/CoreBrainPlayer.cs
-
102unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs
-
19unity-environment/Assets/ML-Agents/Template/Scripts/TemplateDecision.cs
-
12python/curriculum.json
-
71python/unityagents/curriculum.py
-
380unity-environment/Assets/ML-Agents/Scripts/Monitor.cs
-
12unity-environment/Assets/ML-Agents/Scripts/Monitor.cs.meta
-
12unity-environment/Assets/ML-Agents/Scripts/AgentMonitor.cs.meta
-
9unity-environment/Assets/ML-Agents/Resources.meta
|
|||
from .environment import * |
|||
from .brain import * |
|||
from .exception import * |
|||
from .curriculum import * |
|
|||
{ |
|||
"measure" : "reward", |
|||
"thresholds" : [10, 20, 50], |
|||
"min_lesson_length" : 3, |
|||
"signal_smoothing" : true, |
|||
"parameters" : |
|||
{ |
|||
"param1" : [0.7, 0.5, 0.3, 0.1], |
|||
"param2" : [100, 50, 20, 15], |
|||
"param3" : [0.2, 0.3, 0.7, 0.9] |
|||
} |
|||
} |
|
|||
import json |
|||
import numpy as np |
|||
|
|||
from .exception import UnityEnvironmentException |
|||
|
|||
|
|||
class Curriculum(object):
    """
    Steps through a sequence of lessons described by a JSON curriculum file.

    Each lesson assigns one value to every reset parameter listed under
    "parameters". The curriculum advances to the next lesson once the reported
    progress exceeds the current entry of "thresholds" and the current lesson
    has lasted longer than "min_lesson_length" calls to get_lesson.
    """

    def __init__(self, location, default_reset_parameters):
        """
        Loads and validates a curriculum.

        :param location: Path to the JSON file defining the curriculum, or None
            to run without a curriculum.
        :param default_reset_parameters: Container of the reset parameter names
            known to the environment; every curriculum parameter must be in it.
        :raises UnityEnvironmentException: If the file is missing, cannot be
            decoded, lacks a required field, references an unknown parameter, or
            provides the wrong number of values for a parameter.
        """
        self.lesson_number = 0
        self.lesson_length = 0
        self.measure_type = None
        # Initialised unconditionally so that set_lesson_number() and
        # get_lesson() are usable even when no curriculum file is supplied.
        self.max_lesson_number = 0
        self.smoothing_value = 0
        if location is None:
            self.data = None
            return

        try:
            with open(location) as data_file:
                self.data = json.load(data_file)
        except FileNotFoundError:
            raise UnityEnvironmentException(
                "The file {0} could not be found.".format(location))
        except UnicodeDecodeError:
            raise UnityEnvironmentException(
                "There was an error decoding {}".format(location))

        for key in ['parameters', 'measure', 'thresholds',
                    'min_lesson_length', 'signal_smoothing']:
            if key not in self.data:
                raise UnityEnvironmentException("{0} does not contain a "
                                                "{1} field.".format(location, key))

        parameters = self.data['parameters']
        self.measure_type = self.data['measure']
        # N thresholds separate N + 1 lessons, numbered 0..N.
        self.max_lesson_number = len(self.data['thresholds'])

        for key in parameters:
            if key not in default_reset_parameters:
                raise UnityEnvironmentException(
                    "The parameter {0} in Curriculum {1} is not present in "
                    "the Environment".format(key, location))
        expected_count = self.max_lesson_number + 1
        for key in parameters:
            if len(parameters[key]) != expected_count:
                raise UnityEnvironmentException(
                    "The parameter {0} in Curriculum {1} must have {2} values "
                    "but {3} were found".format(key, location,
                                                expected_count, len(parameters[key])))

    @property
    def measure(self):
        """The "measure" field of the curriculum, or None without a curriculum."""
        return self.measure_type

    def get_lesson_number(self):
        """Returns the index of the current lesson."""
        return self.lesson_number

    def set_lesson_number(self, value):
        """
        Jumps to a given lesson and restarts the lesson-length counter.

        :param value: Requested lesson index; clamped to [0, max_lesson_number].
        """
        self.lesson_length = 0
        self.lesson_number = max(0, min(value, self.max_lesson_number))

    def get_lesson(self, progress):
        """
        Records one progress measurement and returns the current lesson's values.

        :param progress: Latest value of the progress measure, or None.
        :return: Dictionary mapping each curriculum parameter to its value for
            the (possibly just advanced) current lesson. Empty when there is no
            curriculum or progress is None.
        """
        if self.data is None or progress is None:
            return {}
        if self.data["signal_smoothing"]:
            # Exponential moving average: 90% previous value, 10% new sample.
            progress = self.smoothing_value * 0.9 + 0.1 * progress
            self.smoothing_value = progress
        self.lesson_length += 1
        if self.lesson_number < self.max_lesson_number:
            if ((progress > self.data['thresholds'][self.lesson_number]) and
                    (self.lesson_length > self.data['min_lesson_length'])):
                self.lesson_length = 0
                self.lesson_number += 1
        config = {}
        parameters = self.data["parameters"]
        for key in parameters:
            config[key] = parameters[key][self.lesson_number]
        return config
|
|||
using System.Collections; |
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
using UnityEngine.UI; |
|||
using Newtonsoft.Json; |
|||
using System.Linq; |
|||
|
|||
|
|||
/** The type of monitor the information must be displayed in.
 * <slider> corresponds to a single rectangle whose width is given
 *   by a float between -1 and 1. (green is positive, red is negative)
 * <hist> corresponds to n vertical sliders.
 * <text> is a text field.
 * <bar> is a rectangle of fixed length to represent the proportions
 *   of a list of floats.
 */
public enum MonitorType
{
    slider,
    hist,
    text,
    bar
}
|||
|
|||
/** Monitor is used to display information. Use the Log function to add
 * information to your monitor.
 *
 * All state is static: values are stored per Transform and drawn every
 * OnGUI call by the single Monitor component that lives on the
 * "AgentMonitorCanvas" GameObject, which is created on demand.
 */
public class Monitor : MonoBehaviour
{
    /** Maximum number of entries kept per transform; older ones are evicted. */
    const int maxDisplayedValues = 20;

    static bool isInstanciated;
    static GameObject canvas;

    /** One logged entry: its payload, how to draw it and when it was first logged. */
    private struct DisplayValue
    {
        public float time;
        public object value;
        public MonitorType monitorDisplayType;
    }

    static Dictionary<Transform, Dictionary<string, DisplayValue>> displayTransformValues;
    static private Color[] barColors;
    [HideInInspector]
    static public float verticalOffset = 3f;
    /**< \brief This float represents how high above the target the monitors will be. */

    static GUIStyle keyStyle;
    static GUIStyle valueStyle;
    static GUIStyle greenStyle;
    static GUIStyle redStyle;
    static GUIStyle[] colorStyle;
    static bool initialized;

    /** Use the Monitor.Log static function to attach information to a transform.
     * If displayType is <text>, value can be any object.
     * If displayType is <slider>, value must be a number (float, double or int).
     * If displayType is <hist>, value must be a List or Array of floats.
     * If displayType is <bar>, value must be a list or Array of positive floats.
     * Note that <slider> and <hist> caps values between -1 and 1.
     * Logging a null value removes the key. Logging an existing key only
     * replaces its value; its display type and position are kept.
     * @param key The name of the information you wish to Log.
     * @param value The value you want to display.
     * @param displayType The type of display.
     * @param target The transform you want to attach the information to.
     *   When null, the information is attached to the monitor canvas itself.
     */
    public static void Log(
        string key,
        object value,
        MonitorType displayType = MonitorType.text,
        Transform target = null)
    {
        EnsureCanvas();

        if (target == null)
        {
            target = canvas.transform;
        }

        // Handle removal first so that it never creates an empty entry table.
        if (value == null)
        {
            RemoveValue(target, key);
            return;
        }

        Dictionary<string, DisplayValue> displayValues;
        if (!displayTransformValues.TryGetValue(target, out displayValues))
        {
            displayValues = new Dictionary<string, DisplayValue>();
            displayTransformValues[target] = displayValues;
        }

        DisplayValue dv;
        if (displayValues.TryGetValue(key, out dv))
        {
            // DisplayValue is a struct: modify the copy and store it back.
            dv.value = value;
            displayValues[key] = dv;
            return;
        }

        dv = new DisplayValue();
        dv.time = Time.timeSinceLevelLoad;
        dv.value = value;
        dv.monitorDisplayType = displayType;
        displayValues[key] = dv;

        // Evict the entries with the smallest creation time until under the cap.
        while (displayValues.Count > maxDisplayedValues)
        {
            string oldest = displayValues.Aggregate((l, r) => l.Value.time < r.Value.time ? l : r).Key;
            displayValues.Remove(oldest);
        }
    }

    /** Remove a value from a monitor
     * @param target The transform to which the information is attached
     *   (null means the monitor canvas)
     * @param key The key of the information you want to remove
     */
    public static void RemoveValue(Transform target, string key)
    {
        // Nothing has been logged yet, so there is nothing to remove.
        if (displayTransformValues == null)
        {
            return;
        }
        if (target == null)
        {
            if (canvas == null)
            {
                return;
            }
            target = canvas.transform;
        }

        Dictionary<string, DisplayValue> displayValues;
        if (!displayTransformValues.TryGetValue(target, out displayValues))
        {
            return;
        }
        if (displayValues.Remove(key) && displayValues.Count == 0)
        {
            displayTransformValues.Remove(target);
        }
    }

    /** Remove all information from a monitor
     * @param target The transform to which the information is attached
     *   (null means the monitor canvas)
     */
    public static void RemoveAllValues(Transform target)
    {
        // Nothing has been logged yet, so there is nothing to remove.
        if (displayTransformValues == null)
        {
            return;
        }
        if (target == null)
        {
            if (canvas == null)
            {
                return;
            }
            target = canvas.transform;
        }
        displayTransformValues.Remove(target);
    }

    /** Use SetActive to enable or disable the Monitor via script
     * @param active Set the Monitor's status to the value of active
     */
    public static void SetActive(bool active)
    {
        EnsureCanvas();
        canvas.SetActive(active);
    }

    /** Creates the canvas on first use, and again if it has since been destroyed. */
    private static void EnsureCanvas()
    {
        // Unity's overloaded == reports destroyed objects as null, which covers
        // the canvas being torn down (for instance when a scene is unloaded).
        if (isInstanciated && canvas != null)
        {
            return;
        }
        InstanciateCanvas();
        isInstanciated = true;
    }

    /** Finds or creates the "AgentMonitorCanvas" GameObject hosting this component. */
    private static void InstanciateCanvas()
    {
        canvas = GameObject.Find("AgentMonitorCanvas");
        if (canvas == null)
        {
            canvas = new GameObject();
            canvas.name = "AgentMonitorCanvas";
            canvas.AddComponent<Monitor>();
        }
        if (displayTransformValues == null)
        {
            displayTransformValues = new Dictionary<Transform, Dictionary<string, DisplayValue>>();
        }
    }

    /** Converts a logged value to a float array on a best-effort basis.
     * Float arrays and float sequences are used directly; anything else goes
     * through a JSON round-trip, first as an array, then as a single number.
     * @param input The logged value.
     * @return The converted values, or an empty array if conversion failed.
     */
    private static float[] ToFloatArray(object input)
    {
        // Fast paths: avoid JSON serialization for the common container types,
        // since this runs for every hist/bar entry on every GUI event.
        float[] asArray = input as float[];
        if (asArray != null)
        {
            return asArray;
        }
        IEnumerable<float> asSequence = input as IEnumerable<float>;
        if (asSequence != null)
        {
            return asSequence.ToArray();
        }

        try
        {
            return JsonConvert.DeserializeObject<float[]>(
                JsonConvert.SerializeObject(input, Formatting.None));
        }
        catch (System.Exception)
        {
            // Deliberate best effort: not an array of numbers, try a scalar next.
        }
        try
        {
            return new float[1]
            {
                JsonConvert.DeserializeObject<float>(
                    JsonConvert.SerializeObject(input, Formatting.None))
            };
        }
        catch (System.Exception)
        {
            // Deliberate best effort: not a number either, fall through to empty.
        }

        return new float[0];
    }

    /** Extracts a float from a boxed float, double or int.
     * @param raw The boxed value.
     * @param result The extracted number, or 0 when raw is not a supported type.
     * @return True when raw was a supported numeric type.
     */
    private static bool TryGetFloat(object raw, out float result)
    {
        if (raw is float)
        {
            result = (float)raw;
            return true;
        }
        if (raw is double)
        {
            result = (float)(double)raw;
            return true;
        }
        if (raw is int)
        {
            result = (int)raw;
            return true;
        }
        result = 0f;
        return false;
    }

    /** Draws every logged entry, newest first, stacked upwards from its anchor. */
    void OnGUI()
    {
        // The component may exist in a scene before anything has been logged.
        if (displayTransformValues == null)
        {
            return;
        }

        if (!initialized)
        {
            Initialize();
            initialized = true;
        }

        // Resolved once per call instead of once per use.
        Camera mainCamera = Camera.main;
        Transform canvasTransform = canvas != null ? canvas.transform : null;
        float widthScaler = Screen.width / 1000f;

        // Iterate over a snapshot because destroyed targets are removed on the fly.
        foreach (Transform target in displayTransformValues.Keys.ToList())
        {
            if (target == null)
            {
                displayTransformValues.Remove(target);
                continue;
            }

            float keyPixelWidth = 100 * widthScaler;
            float keyPixelHeight = 20 * widthScaler;
            float paddingwidth = 10 * widthScaler;

            // Canvas entries are anchored to the bottom-left corner of the screen;
            // any other target is anchored above its world position and shrinks
            // with its distance along the camera's forward axis.
            float scale = 1f;
            Vector2 origin = new Vector2(0, Screen.height);
            if (target != canvasTransform)
            {
                if (mainCamera == null)
                {
                    // Without a main camera there is no way to project the target.
                    continue;
                }
                Vector3 cam2obj = target.position - mainCamera.transform.position;
                scale = Mathf.Min(1, 20f / Vector3.Dot(cam2obj, mainCamera.transform.forward));
                Vector3 screenPosition = mainCamera.WorldToScreenPoint(
                    target.position + new Vector3(0, verticalOffset, 0));
                origin = new Vector2(
                    screenPosition.x - keyPixelWidth * scale,
                    Screen.height - screenPosition.y);
            }
            keyPixelWidth *= scale;
            keyPixelHeight *= scale;
            paddingwidth *= scale;

            // Too small to read (or behind the camera, where scale is negative).
            int fontSize = (int)(keyPixelHeight * 0.8f);
            if (fontSize < 2)
            {
                continue;
            }
            keyStyle.fontSize = fontSize;
            valueStyle.fontSize = fontSize;

            float valueX = origin.x + paddingwidth + keyPixelWidth;
            int index = 0;
            foreach (KeyValuePair<string, DisplayValue> entry in
                displayTransformValues[target].OrderByDescending(e => e.Value.time))
            {
                // Bottom edge of this row; rows stack upwards from the origin.
                float rowY = origin.y - index * keyPixelHeight;

                GUI.Label(
                    new Rect(origin.x, rowY - keyPixelHeight, keyPixelWidth, keyPixelHeight),
                    entry.Key, keyStyle);

                switch (entry.Value.monitorDisplayType)
                {
                    case MonitorType.text:
                        GUI.Label(
                            new Rect(valueX, rowY - keyPixelHeight, keyPixelWidth, keyPixelHeight),
                            JsonConvert.SerializeObject(entry.Value.value, Formatting.None),
                            valueStyle);
                        break;
                    case MonitorType.slider:
                        DrawSlider(entry.Key, entry.Value.value, valueX, rowY,
                            keyPixelWidth, keyPixelHeight);
                        break;
                    case MonitorType.hist:
                        DrawHist(entry.Value.value, valueX, rowY,
                            keyPixelWidth, keyPixelHeight, paddingwidth);
                        break;
                    case MonitorType.bar:
                        DrawBar(entry.Key, entry.Value.value, valueX, rowY,
                            keyPixelWidth, keyPixelHeight);
                        break;
                }

                index++;
            }
        }
    }

    /** Draws one horizontal bar whose length is |value| capped at 1 (green >= 0, red < 0). */
    private static void DrawSlider(
        string key, object rawValue, float x, float rowY, float width, float height)
    {
        float sliderValue;
        if (!TryGetFloat(rawValue, out sliderValue))
        {
            Debug.LogError(string.Format("The value for {0} could not be displayed as " +
                "a slider because it is not a number.", key));
        }

        GUIStyle style = sliderValue < 0 ? redStyle : greenStyle;
        float magnitude = Mathf.Min(1f, Mathf.Abs(sliderValue));
        GUI.Box(
            new Rect(x, rowY - 0.9f * height, width * magnitude, height * 0.8f),
            GUIContent.none, style);
    }

    /** Draws one vertical bar per value, each capped at 1 (green >= 0, red < 0). */
    private static void DrawHist(
        object rawValue, float x, float rowY, float width, float height, float padding)
    {
        const float histWidth = 0.15f;
        float[] vals = ToFloatArray(rawValue);
        for (int i = 0; i < vals.Length; i++)
        {
            GUIStyle style = vals[i] < 0 ? redStyle : greenStyle;
            float magnitude = Mathf.Min(1f, Mathf.Abs(vals[i]));
            // The negative height makes the bar extend upwards from the row's base.
            GUI.Box(
                new Rect(
                    x + (width * histWidth + padding / 2) * i,
                    rowY - 0.1f * height,
                    width * histWidth, -height * magnitude),
                GUIContent.none, style);
        }
    }

    /** Draws a fixed-length bar split into coloured segments proportional to the values. */
    private static void DrawBar(
        string key, object rawValue, float x, float rowY, float width, float height)
    {
        float[] vals = ToFloatArray(rawValue);

        // Negative entries count as zero.
        float valsSum = 0f;
        foreach (float f in vals)
        {
            valsSum += Mathf.Max(f, 0);
        }
        if (valsSum == 0)
        {
            Debug.LogError(string.Format("The Monitor value for key {0} must be "
                + "a list or array of positive values and cannot be empty.", key));
            return;
        }

        float valsCum = 0f;
        for (int i = 0; i < vals.Length; i++)
        {
            float share = Mathf.Max(vals[i], 0) / valsSum;
            GUI.Box(
                new Rect(x + width * valsCum, rowY - 0.9f * height, width * share, height * 0.8f),
                GUIContent.none, colorStyle[i % colorStyle.Length]);
            valsCum += share;
        }
    }

    /** Builds the GUI styles. Must run inside OnGUI because it reads GUI.skin. */
    private void Initialize()
    {
        // Work on private copies: GUI.skin.label is shared, so editing it in
        // place would restyle every other IMGUI label in the application.
        keyStyle = new GUIStyle(GUI.skin.label);
        keyStyle.clipping = TextClipping.Overflow;
        keyStyle.wordWrap = false;
        keyStyle.alignment = TextAnchor.MiddleRight;

        valueStyle = new GUIStyle(keyStyle);
        valueStyle.alignment = TextAnchor.MiddleLeft;

        // One solid-colour 1x1 texture per palette entry, used as box backgrounds.
        barColors = new Color[6] { Color.magenta, Color.blue, Color.cyan, Color.green, Color.yellow, Color.red };
        colorStyle = new GUIStyle[barColors.Length];
        for (int i = 0; i < barColors.Length; i++)
        {
            Texture2D texture = new Texture2D(1, 1, TextureFormat.ARGB32, false);
            texture.SetPixel(0, 0, barColors[i]);
            texture.Apply();
            GUIStyle staticRectStyle = new GUIStyle();
            staticRectStyle.normal.background = texture;
            colorStyle[i] = staticRectStyle;
        }
        greenStyle = colorStyle[3];
        redStyle = colorStyle[5];
    }
}
|
|||
fileFormatVersion: 2 |
|||
guid: e59a31a1cc2f5464d9a61bef0bc9a53b |
|||
timeCreated: 1508031727 |
|||
licenseType: Free |
|||
MonoImporter: |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: e040eaa8759024abbbb14994dc4c55ee |
|||
timeCreated: 1502056030 |
|||
licenseType: Free |
|||
MonoImporter: |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 10f3eff160a3b46fcb86042594151eae |
|||
folderAsset: yes |
|||
timeCreated: 1501551323 |
|||
licenseType: Free |
|||
DefaultImporter: |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
撰写
预览
正在加载...
取消
保存
Reference in new issue