浏览代码
ObservableAttribute (#3925)
ObservableAttribute (#3925)
* ObservableAttribute proof-of-concept * restructue sensors, add int impl, unit test * add vector3 sensor, cleanup constructors * add more types * account for observables in barracuda checks * iterators for observable fields/props * stacking, fix obs size in prefab * use DeclaredOnly to filter members * ignore write-only properties * fix error message * docstrings * agent enum (WIP) * agent enum and unit tests * fix comment * cleanup TODO * ignore by default, rename declaredOnly param, docstrings * fix tests * rename, cleanup, revert FoodCollector * warning for write-only, no exception for invalid type * move observableAttributeHandling to BehaviorParameters * autoformatting * changelog * fix up sensor creation logic/docs-update
GitHub
5 年前
当前提交
197cf3e7
共有 32 个文件被更改,包括 1153 次插入 和 36 次删除
-
2com.unity.ml-agents/CHANGELOG.md
-
23com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
-
18com.unity.ml-agents/Runtime/Agent.cs
-
43com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
-
43com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
-
6com.unity.ml-agents/Runtime/Policies/BrainParameters.cs
-
24com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs
-
65com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
-
10com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
-
8com.unity.ml-agents/Runtime/Sensors/Reflection.meta
-
11com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta
-
292com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs
-
11com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta
-
19com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta
-
19com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta
-
19com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta
-
295com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta
-
22com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta
-
97com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta
-
20com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta
-
21com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta
-
22com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta
|
|||
fileFormatVersion: 2 |
|||
guid: 08ece3d7e9bb94089a9d59c6f269ab0a |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: e5e4df2934c014aa3b835b9eb9ad20b3 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using NUnit.Framework; |
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
using Unity.MLAgents.Sensors.Reflection; |
|||
|
|||
namespace Unity.MLAgents.Tests |
|||
{ |
|||
[TestFixture] |
|||
public class ObservableAttributeTests |
|||
{ |
|||
class TestClass |
|||
{ |
|||
// Non-observables
|
|||
int m_NonObservableInt; |
|||
float m_NonObservableFloat; |
|||
|
|||
//
|
|||
// Int
|
|||
//
|
|||
[Observable] |
|||
public int m_IntMember; |
|||
|
|||
int m_IntProperty; |
|||
|
|||
[Observable] |
|||
public int IntProperty |
|||
{ |
|||
get => m_IntProperty; |
|||
set => m_IntProperty = value; |
|||
} |
|||
|
|||
//
|
|||
// Float
|
|||
//
|
|||
[Observable("floatMember")] |
|||
public float m_FloatMember; |
|||
|
|||
float m_FloatProperty; |
|||
[Observable("floatProperty")] |
|||
public float FloatProperty |
|||
{ |
|||
get => m_FloatProperty; |
|||
set => m_FloatProperty = value; |
|||
} |
|||
|
|||
//
|
|||
// Bool
|
|||
//
|
|||
[Observable("boolMember")] |
|||
public bool m_BoolMember; |
|||
|
|||
bool m_BoolProperty; |
|||
[Observable("boolProperty")] |
|||
public bool BoolProperty |
|||
{ |
|||
get => m_BoolProperty; |
|||
set => m_BoolProperty = value; |
|||
} |
|||
|
|||
//
|
|||
// Vector2
|
|||
//
|
|||
|
|||
[Observable("vector2Member")] |
|||
public Vector2 m_Vector2Member; |
|||
|
|||
Vector2 m_Vector2Property; |
|||
|
|||
[Observable("vector2Property")] |
|||
public Vector2 Vector2Property |
|||
{ |
|||
get => m_Vector2Property; |
|||
set => m_Vector2Property = value; |
|||
} |
|||
|
|||
//
|
|||
// Vector3
|
|||
//
|
|||
[Observable("vector3Member")] |
|||
public Vector3 m_Vector3Member; |
|||
|
|||
Vector3 m_Vector3Property; |
|||
|
|||
[Observable("vector3Property")] |
|||
public Vector3 Vector3Property |
|||
{ |
|||
get => m_Vector3Property; |
|||
set => m_Vector3Property = value; |
|||
} |
|||
|
|||
//
|
|||
// Vector4
|
|||
//
|
|||
|
|||
[Observable("vector4Member")] |
|||
public Vector4 m_Vector4Member; |
|||
|
|||
Vector4 m_Vector4Property; |
|||
|
|||
[Observable("vector4Property")] |
|||
public Vector4 Vector4Property |
|||
{ |
|||
get => m_Vector4Property; |
|||
set => m_Vector4Property = value; |
|||
} |
|||
|
|||
//
|
|||
// Quaternion
|
|||
//
|
|||
[Observable("quaternionMember")] |
|||
public Quaternion m_QuaternionMember; |
|||
|
|||
Quaternion m_QuaternionProperty; |
|||
|
|||
[Observable("quaternionProperty")] |
|||
public Quaternion QuaternionProperty |
|||
{ |
|||
get => m_QuaternionProperty; |
|||
set => m_QuaternionProperty = value; |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestGetObservableSensors() |
|||
{ |
|||
var testClass = new TestClass(); |
|||
testClass.m_IntMember = 1; |
|||
testClass.IntProperty = 2; |
|||
|
|||
testClass.m_FloatMember = 1.1f; |
|||
testClass.FloatProperty = 1.2f; |
|||
|
|||
testClass.m_BoolMember = true; |
|||
testClass.BoolProperty = true; |
|||
|
|||
testClass.m_Vector2Member = new Vector2(2.0f, 2.1f); |
|||
testClass.Vector2Property = new Vector2(2.2f, 2.3f); |
|||
|
|||
testClass.m_Vector3Member = new Vector3(3.0f, 3.1f, 3.2f); |
|||
testClass.Vector3Property = new Vector3(3.3f, 3.4f, 3.5f); |
|||
|
|||
testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); |
|||
testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); |
|||
|
|||
testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); |
|||
testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); |
|||
|
|||
testClass.m_QuaternionMember = new Quaternion(5.0f, 5.1f, 5.2f, 5.3f); |
|||
testClass.QuaternionProperty = new Quaternion(5.4f, 5.5f, 5.5f, 5.7f); |
|||
|
|||
var sensors = ObservableAttribute.CreateObservableSensors(testClass, false); |
|||
|
|||
var sensorsByName = new Dictionary<string, ISensor>(); |
|||
foreach (var sensor in sensors) |
|||
{ |
|||
sensorsByName[sensor.GetName()] = sensor; |
|||
} |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] { 1.0f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] { 2.0f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["floatMember"], new[] { 1.1f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["floatProperty"], new[] { 1.2f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["boolMember"], new[] { 1.0f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["boolProperty"], new[] { 1.0f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["vector2Member"], new[] { 2.0f, 2.1f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["vector2Property"], new[] { 2.2f, 2.3f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["vector3Member"], new[] { 3.0f, 3.1f, 3.2f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["vector3Property"], new[] { 3.3f, 3.4f, 3.5f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["vector4Member"], new[] { 4.0f, 4.1f, 4.2f, 4.3f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["vector4Property"], new[] { 4.4f, 4.5f, 4.5f, 4.7f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] { 5.0f, 5.1f, 5.2f, 5.3f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] { 5.4f, 5.5f, 5.5f, 5.7f }); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestGetTotalObservationSize() |
|||
{ |
|||
var testClass = new TestClass(); |
|||
var errors = new List<string>(); |
|||
var expectedObsSize = 2 * (1 + 1 + 1 + 2 + 3 + 4 + 4); |
|||
Assert.AreEqual(expectedObsSize, ObservableAttribute.GetTotalObservationSize(testClass, false, errors)); |
|||
Assert.AreEqual(0, errors.Count); |
|||
} |
|||
|
|||
class BadClass |
|||
{ |
|||
[Observable] |
|||
double m_Double; |
|||
|
|||
[Observable] |
|||
double DoubleProperty |
|||
{ |
|||
get => m_Double; |
|||
set => m_Double = value; |
|||
} |
|||
|
|||
float m_WriteOnlyProperty; |
|||
|
|||
[Observable] |
|||
// No get property, so we shouldn't be able to make a sensor out of this.
|
|||
public float WriteOnlyProperty |
|||
{ |
|||
set => m_WriteOnlyProperty = value; |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestInvalidObservables() |
|||
{ |
|||
var bad = new BadClass(); |
|||
bad.WriteOnlyProperty = 1.0f; |
|||
var errors = new List<string>(); |
|||
Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, false, errors)); |
|||
Assert.AreEqual(3, errors.Count); |
|||
|
|||
// Should be able to safely generate sensors (and get nothing back)
|
|||
var sensors = ObservableAttribute.CreateObservableSensors(bad, false); |
|||
Assert.AreEqual(0, sensors.Count); |
|||
} |
|||
|
|||
class StackingClass |
|||
{ |
|||
[Observable(numStackedObservations: 2)] |
|||
public float FloatVal; |
|||
} |
|||
|
|||
[Test] |
|||
public void TestObservableAttributeStacking() |
|||
{ |
|||
var c = new StackingClass(); |
|||
c.FloatVal = 1.0f; |
|||
var sensors = ObservableAttribute.CreateObservableSensors(c, false); |
|||
var sensor = sensors[0]; |
|||
Assert.AreEqual(typeof(StackingSensor), sensor.GetType()); |
|||
SensorTestHelper.CompareObservation(sensor, new[] { 0.0f, 1.0f }); |
|||
|
|||
sensor.Update(); |
|||
c.FloatVal = 3.0f; |
|||
SensorTestHelper.CompareObservation(sensor, new[] { 1.0f, 3.0f }); |
|||
|
|||
var errors = new List<string>(); |
|||
Assert.AreEqual(2, ObservableAttribute.GetTotalObservationSize(c, false, errors)); |
|||
Assert.AreEqual(0, errors.Count); |
|||
} |
|||
|
|||
class BaseClass |
|||
{ |
|||
[Observable("base")] |
|||
public float m_BaseField; |
|||
|
|||
[Observable("private")] |
|||
float m_PrivateField; |
|||
} |
|||
|
|||
class DerivedClass : BaseClass |
|||
{ |
|||
[Observable("derived")] |
|||
float m_DerivedField; |
|||
} |
|||
|
|||
[Test] |
|||
public void TestObservableAttributeExcludeInherited() |
|||
{ |
|||
var d = new DerivedClass(); |
|||
d.m_BaseField = 1.0f; |
|||
|
|||
// excludeInherited=false will get fields in the derived class, plus public and protected inherited fields
|
|||
var sensorAll = ObservableAttribute.CreateObservableSensors(d, false); |
|||
Assert.AreEqual(2, sensorAll.Count); |
|||
// Note - actual order doesn't matter here, we can change this to use a HashSet if neeed.
|
|||
Assert.AreEqual("derived", sensorAll[0].GetName()); |
|||
Assert.AreEqual("base", sensorAll[1].GetName()); |
|||
|
|||
// excludeInherited=true will only get fields in the derived class
|
|||
var sensorsDerivedOnly = ObservableAttribute.CreateObservableSensors(d, true); |
|||
Assert.AreEqual(1, sensorsDerivedOnly.Count); |
|||
Assert.AreEqual("derived", sensorsDerivedOnly[0].GetName()); |
|||
|
|||
var b = new BaseClass(); |
|||
var baseSensors = ObservableAttribute.CreateObservableSensors(b, false); |
|||
Assert.AreEqual(2, baseSensors.Count); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 33d7912e6b3504412bd261b40e46df32 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a boolean field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class BoolReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
internal BoolReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 1) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var boolVal = (System.Boolean)GetReflectedValue(); |
|||
writer[0] = boolVal ? 1.0f : 0.0f; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: be795c90750a6420d93f569b69ddc1ba |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a float field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class FloatReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
internal FloatReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 1) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var floatVal = (System.Single)GetReflectedValue(); |
|||
writer[0] = floatVal; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 51ed837d5b7cd44349287ac8066120fc |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps an integer field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class IntReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
internal IntReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 1) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var intVal = (System.Int32)GetReflectedValue(); |
|||
writer[0] = (float)intVal; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 5cae4c843cc074d11a549aaa3904c898 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using System.Reflection; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Specify that a field or property should be used to generate observations for an Agent.
|
|||
/// For each field or property that uses ObservableAttribute, a corresponding
|
|||
/// <see cref="ISensor"/> will be created during Agent initialization, and this
|
|||
/// sensor will read the values during training and inference.
|
|||
/// </summary>
|
|||
/// <remarks>
|
|||
/// ObservableAttribute is intended to make initial setup of an Agent easier. Because it
|
|||
/// uses reflection to read the values of fields and properties at runtime, this may
|
|||
/// be much slower than reading the values directly. If the performance of
|
|||
/// ObservableAttribute is an issue, you can get the same functionality by overriding
|
|||
/// <see cref="Agent.CollectObservations(VectorSensor)"/> or creating a custom
|
|||
/// <see cref="ISensor"/> implementation to read the values without reflection.
|
|||
///
|
|||
/// Note that you do not need to adjust the VectorObservationSize in
|
|||
/// <see cref="Unity.MLAgents.Policies.BrainParameters"/> when adding ObservableAttribute
|
|||
/// to fields or properties.
|
|||
/// </remarks>
|
|||
/// <example>
|
|||
/// This sample class will produce two observations, one for the m_Health field, and one
|
|||
/// for the HealthPercent property.
|
|||
/// <code>
|
|||
/// using Unity.MLAgents;
|
|||
/// using Unity.MLAgents.Sensors.Reflection;
|
|||
///
|
|||
/// public class MyAgent : Agent
|
|||
/// {
|
|||
/// [Observable]
|
|||
/// int m_Health;
|
|||
///
|
|||
/// [Observable]
|
|||
/// float HealthPercent
|
|||
/// {
|
|||
/// get => return 100.0f * m_Health / float(m_MaxHealth);
|
|||
/// }
|
|||
/// }
|
|||
/// </code>
|
|||
/// </example>
|
|||
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] |
|||
public class ObservableAttribute : Attribute |
|||
{ |
|||
string m_Name; |
|||
int m_NumStackedObservations; |
|||
|
|||
/// <summary>
|
|||
/// Default binding flags used for reflection of members and properties.
|
|||
/// </summary>
|
|||
const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; |
|||
|
|||
/// <summary>
|
|||
/// Supported types and their observation sizes.
|
|||
/// </summary>
|
|||
static Dictionary<Type, int> s_TypeSizes = new Dictionary<Type, int>() |
|||
{ |
|||
{typeof(int), 1}, |
|||
{typeof(bool), 1}, |
|||
{typeof(float), 1}, |
|||
{typeof(Vector2), 2}, |
|||
{typeof(Vector3), 3}, |
|||
{typeof(Vector4), 4}, |
|||
{typeof(Quaternion), 4}, |
|||
}; |
|||
|
|||
/// <summary>
|
|||
/// ObservableAttribute constructor.
|
|||
/// </summary>
|
|||
/// <param name="name">Optional override for the sensor name. Note that all sensors for an Agent
|
|||
/// must have a unique name.</param>
|
|||
/// <param name="numStackedObservations">Number of frames to concatenate observations from.</param>
|
|||
public ObservableAttribute(string name = null, int numStackedObservations = 1) |
|||
{ |
|||
m_Name = name; |
|||
m_NumStackedObservations = numStackedObservations; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns a FieldInfo for all fields that have an ObservableAttribute
|
|||
/// </summary>
|
|||
/// <param name="o">Object being reflected</param>
|
|||
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
|
|||
/// <returns></returns>
|
|||
static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool excludeInherited) |
|||
{ |
|||
// TODO cache these (and properties) by type, so that we only have to reflect once.
|
|||
var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0); |
|||
var fields = o.GetType().GetFields(bindingFlags); |
|||
foreach (var field in fields) |
|||
{ |
|||
var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute)); |
|||
if (attr != null) |
|||
{ |
|||
yield return (field, attr); |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns a PropertyInfo for all fields that have an ObservableAttribute
|
|||
/// </summary>
|
|||
/// <param name="o">Object being reflected</param>
|
|||
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
|
|||
/// <returns></returns>
|
|||
static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool excludeInherited) |
|||
{ |
|||
var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0); |
|||
var properties = o.GetType().GetProperties(bindingFlags); |
|||
foreach (var prop in properties) |
|||
{ |
|||
var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute)); |
|||
if (attr != null) |
|||
{ |
|||
yield return (prop, attr); |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Creates sensors for each field and property with ObservableAttribute.
|
|||
/// </summary>
|
|||
/// <param name="o">Object being reflected</param>
|
|||
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
|
|||
/// <returns></returns>
|
|||
internal static List<ISensor> CreateObservableSensors(object o, bool excludeInherited) |
|||
{ |
|||
var sensorsOut = new List<ISensor>(); |
|||
foreach (var(field, attr) in GetObservableFields(o, excludeInherited)) |
|||
{ |
|||
var sensor = CreateReflectionSensor(o, field, null, attr); |
|||
if (sensor != null) |
|||
{ |
|||
sensorsOut.Add(sensor); |
|||
} |
|||
} |
|||
|
|||
foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited)) |
|||
{ |
|||
if (!prop.CanRead) |
|||
{ |
|||
// Skip unreadable properties.
|
|||
continue; |
|||
} |
|||
var sensor = CreateReflectionSensor(o, null, prop, attr); |
|||
if (sensor != null) |
|||
{ |
|||
sensorsOut.Add(sensor); |
|||
} |
|||
} |
|||
|
|||
return sensorsOut; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Create the ISensor for either the field or property on the provided object.
|
|||
/// If the data type is unsupported, or the property is write-only, returns null.
|
|||
/// </summary>
|
|||
/// <param name="o"></param>
|
|||
/// <param name="fieldInfo"></param>
|
|||
/// <param name="propertyInfo"></param>
|
|||
/// <param name="observableAttribute"></param>
|
|||
/// <returns></returns>
|
|||
/// <exception cref="UnityAgentsException"></exception>
|
|||
static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute) |
|||
{ |
|||
string memberName; |
|||
string declaringTypeName; |
|||
Type memberType; |
|||
if (fieldInfo != null) |
|||
{ |
|||
declaringTypeName = fieldInfo.DeclaringType.Name; |
|||
memberName = fieldInfo.Name; |
|||
memberType = fieldInfo.FieldType; |
|||
} |
|||
else |
|||
{ |
|||
declaringTypeName = propertyInfo.DeclaringType.Name; |
|||
memberName = propertyInfo.Name; |
|||
memberType = propertyInfo.PropertyType; |
|||
} |
|||
|
|||
string sensorName; |
|||
if (string.IsNullOrEmpty(observableAttribute.m_Name)) |
|||
{ |
|||
sensorName = $"ObservableAttribute:{declaringTypeName}.{memberName}"; |
|||
} |
|||
else |
|||
{ |
|||
sensorName = observableAttribute.m_Name; |
|||
} |
|||
|
|||
var reflectionSensorInfo = new ReflectionSensorInfo |
|||
{ |
|||
Object = o, |
|||
FieldInfo = fieldInfo, |
|||
PropertyInfo = propertyInfo, |
|||
ObservableAttribute = observableAttribute, |
|||
SensorName = sensorName |
|||
}; |
|||
|
|||
ISensor sensor = null; |
|||
if (memberType == typeof(Int32)) |
|||
{ |
|||
sensor = new IntReflectionSensor(reflectionSensorInfo); |
|||
} |
|||
else if (memberType == typeof(float)) |
|||
{ |
|||
sensor = new FloatReflectionSensor(reflectionSensorInfo); |
|||
} |
|||
else if (memberType == typeof(bool)) |
|||
{ |
|||
sensor = new BoolReflectionSensor(reflectionSensorInfo); |
|||
} |
|||
else if (memberType == typeof(Vector2)) |
|||
{ |
|||
sensor = new Vector2ReflectionSensor(reflectionSensorInfo); |
|||
} |
|||
else if (memberType == typeof(Vector3)) |
|||
{ |
|||
sensor = new Vector3ReflectionSensor(reflectionSensorInfo); |
|||
} |
|||
else if (memberType == typeof(Vector4)) |
|||
{ |
|||
sensor = new Vector4ReflectionSensor(reflectionSensorInfo); |
|||
} |
|||
else if (memberType == typeof(Quaternion)) |
|||
{ |
|||
sensor = new QuaternionReflectionSensor(reflectionSensorInfo); |
|||
} |
|||
else |
|||
{ |
|||
// For unsupported types, return null and we'll filter them out later.
|
|||
return null; |
|||
} |
|||
|
|||
// Wrap the base sensor in a StackingSensor if we're using stacking.
|
|||
if (observableAttribute.m_NumStackedObservations > 1) |
|||
{ |
|||
return new StackingSensor(sensor, observableAttribute.m_NumStackedObservations); |
|||
} |
|||
|
|||
return sensor; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Gets the sum of the observation sizes of the Observable fields and properties on an object.
|
|||
/// Also appends errors to the errorsOut array.
|
|||
/// </summary>
|
|||
/// <param name="o"></param>
|
|||
/// <param name="excludeInherited"></param>
|
|||
/// <param name="errorsOut"></param>
|
|||
/// <returns></returns>
|
|||
internal static int GetTotalObservationSize(object o, bool excludeInherited, List<string> errorsOut) |
|||
{ |
|||
int sizeOut = 0; |
|||
foreach (var(field, attr) in GetObservableFields(o, excludeInherited)) |
|||
{ |
|||
if (s_TypeSizes.ContainsKey(field.FieldType)) |
|||
{ |
|||
sizeOut += s_TypeSizes[field.FieldType] * attr.m_NumStackedObservations; |
|||
} |
|||
else |
|||
{ |
|||
errorsOut.Add($"Unsupported Observable type {field.FieldType.Name} on field {field.Name}"); |
|||
} |
|||
} |
|||
|
|||
foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited)) |
|||
{ |
|||
if (s_TypeSizes.ContainsKey(prop.PropertyType)) |
|||
{ |
|||
if (prop.CanRead) |
|||
{ |
|||
sizeOut += s_TypeSizes[prop.PropertyType] * attr.m_NumStackedObservations; |
|||
} |
|||
else |
|||
{ |
|||
errorsOut.Add($"Observable property {prop.Name} is write-only."); |
|||
} |
|||
} |
|||
else |
|||
{ |
|||
errorsOut.Add($"Unsupported Observable type {prop.PropertyType.Name} on property {prop.Name}"); |
|||
} |
|||
} |
|||
|
|||
return sizeOut; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: a75086dc66a594baea6b8b2935f5dacf |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a quaternion field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class QuaternionReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
internal QuaternionReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 4) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var quatVal = (UnityEngine.Quaternion)GetReflectedValue(); |
|||
writer[0] = quatVal.x; |
|||
writer[1] = quatVal.y; |
|||
writer[2] = quatVal.z; |
|||
writer[3] = quatVal.w; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: d38241d74074d459bb4590f7f5d16c80 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System.Reflection; |
|||
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Construction info for a ReflectionSensorBase.
|
|||
/// </summary>
|
|||
internal struct ReflectionSensorInfo |
|||
{ |
|||
public object Object; |
|||
|
|||
public FieldInfo FieldInfo; |
|||
public PropertyInfo PropertyInfo; |
|||
public ObservableAttribute ObservableAttribute; |
|||
public string SensorName; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Abstract base class for reflection-based sensors.
|
|||
/// </summary>
|
|||
internal abstract class ReflectionSensorBase : ISensor |
|||
{ |
|||
protected object m_Object; |
|||
|
|||
// Exactly one of m_FieldInfo and m_PropertyInfo should be non-null.
|
|||
protected FieldInfo m_FieldInfo; |
|||
protected PropertyInfo m_PropertyInfo; |
|||
|
|||
// Not currently used, but might want later.
|
|||
protected ObservableAttribute m_ObservableAttribute; |
|||
|
|||
// Cached sensor names and shapes.
|
|||
string m_SensorName; |
|||
int[] m_Shape; |
|||
|
|||
public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size) |
|||
{ |
|||
m_Object = reflectionSensorInfo.Object; |
|||
m_FieldInfo = reflectionSensorInfo.FieldInfo; |
|||
m_PropertyInfo = reflectionSensorInfo.PropertyInfo; |
|||
m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute; |
|||
m_SensorName = reflectionSensorInfo.SensorName; |
|||
m_Shape = new[] {size}; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public int[] GetObservationShape() |
|||
{ |
|||
return m_Shape; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public int Write(ObservationWriter writer) |
|||
{ |
|||
WriteReflectedField(writer); |
|||
return m_Shape[0]; |
|||
} |
|||
|
|||
internal abstract void WriteReflectedField(ObservationWriter writer); |
|||
|
|||
/// <summary>
|
|||
/// Get either the reflected field, or return the reflected property.
|
|||
/// This should be used by implementations in their WriteReflectedField() method.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
protected object GetReflectedValue() |
|||
{ |
|||
return m_FieldInfo != null ? |
|||
m_FieldInfo.GetValue(m_Object) : |
|||
m_PropertyInfo.GetMethod.Invoke(m_Object, null); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public byte[] GetCompressedObservation() |
|||
{ |
|||
return null; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public void Update() {} |
|||
|
|||
/// <inheritdoc/>
|
|||
public void Reset() {} |
|||
|
|||
/// <inheritdoc/>
|
|||
public SensorCompressionType GetCompressionType() |
|||
{ |
|||
return SensorCompressionType.None; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public string GetName() |
|||
{ |
|||
return m_SensorName; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 6b68d855fb94a45fbbeb0dbe968a35f8 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a Vector2 field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class Vector2ReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
internal Vector2ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 2) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var vecVal = (UnityEngine.Vector2)GetReflectedValue(); |
|||
writer[0] = vecVal.x; |
|||
writer[1] = vecVal.y; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: da06ff33f6f2d409cbf240cffa2ba0be |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a Vector3 field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class Vector3ReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
internal Vector3ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 3) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var vecVal = (UnityEngine.Vector3)GetReflectedValue(); |
|||
writer[0] = vecVal.x; |
|||
writer[1] = vecVal.y; |
|||
writer[2] = vecVal.z; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: e756976ec2a0943cfbc0f97a6550a85b |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a Vector4 field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class Vector4ReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
internal Vector4ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 4) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var vecVal = (UnityEngine.Vector4)GetReflectedValue(); |
|||
writer[0] = vecVal.x; |
|||
writer[1] = vecVal.y; |
|||
writer[2] = vecVal.z; |
|||
writer[3] = vecVal.w; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 01d93aaa1b42b47b8960d303d7c498d3 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
撰写
预览
正在加载...
取消
保存
Reference in new issue