浏览代码

ObservableAttribute (#3925)

* ObservableAttribute proof-of-concept

* restructue sensors, add int impl, unit test

* add vector3 sensor, cleanup constructors

* add more types

* account for observables in barracuda checks

* iterators for observable fields/props

* stacking, fix obs size in prefab

* use DeclaredOnly to filter members

* ignore write-only properties

* fix error message

* docstrings

* agent enum (WIP)

* agent enum and unit tests

* fix comment

* cleanup TODO

* ignore by default, rename declaredOnly param, docstrings

* fix tests

* rename, cleanup, revert FoodCollector

* warning for write-only, no exception for invalid type

* move observableAttributeHandling to BehaviorParameters

* autoformatting

* changelog

* fix up sensor creation logic
/docs-update
GitHub 5 年前
当前提交
197cf3e7
共有 32 个文件被更改,包括 1153 次插入36 次删除
  1. 2
      com.unity.ml-agents/CHANGELOG.md
  2. 23
      com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
  3. 18
      com.unity.ml-agents/Runtime/Agent.cs
  4. 43
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  5. 43
      com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
  6. 6
      com.unity.ml-agents/Runtime/Policies/BrainParameters.cs
  7. 24
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs
  8. 65
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  9. 10
      com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
  10. 8
      com.unity.ml-agents/Runtime/Sensors/Reflection.meta
  11. 11
      com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta
  12. 292
      com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs
  13. 11
      com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta
  14. 19
      com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs
  15. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta
  16. 19
      com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs
  17. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta
  18. 19
      com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs
  19. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta
  20. 295
      com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs
  21. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta
  22. 22
      com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs
  23. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta
  24. 97
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
  25. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta
  26. 20
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs
  27. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta
  28. 21
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs
  29. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta
  30. 22
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs
  31. 11
      com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta

2
com.unity.ml-agents/CHANGELOG.md


- `get_behavior_names()` and `get_behavior_spec()` on UnityEnvironment were replaced by the `behavior_specs` property. (#3946)
### Minor Changes
#### com.unity.ml-agents (C#)
- `ObservableAttribute` was added. Adding the attribute to fields or properties on an Agent will allow it to generate
observations via reflection.
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Curriculum and Parameter Randomization configurations have been merged
into the main training configuration file. Note that this means training

23
com.unity.ml-agents/Editor/BehaviorParametersEditor.cs


using System.Collections.Generic;
using Unity.MLAgents.Sensors.Reflection;
using UnityEngine;
namespace Unity.MLAgents.Editor

EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
EditorGUILayout.PropertyField(so.FindProperty("m_UseChildSensors"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservableAttributeHandling"), true);
}
EditorGUI.EndDisabledGroup();

Model barracudaModel = null;
var model = (NNModel)serializedObject.FindProperty("m_Model").objectReferenceValue;
var behaviorParameters = (BehaviorParameters)target;
// Grab the sensor components, since we need them to determine the observation sizes.
SensorComponent[] sensorComponents;
if (behaviorParameters.UseChildSensors)
{

{
sensorComponents = behaviorParameters.GetComponents<SensorComponent>();
}
// Get the total size of the sensors generated by ObservableAttributes.
// If there are any errors (e.g. unsupported type, write-only properties), display them too.
int observableAttributeSensorTotalSize = 0;
var agent = behaviorParameters.GetComponent<Agent>();
if (agent != null && behaviorParameters.ObservableAttributeHandling != ObservableAttributeOptions.Ignore)
{
List<string> observableErrors = new List<string>();
observableAttributeSensorTotalSize = ObservableAttribute.GetTotalObservationSize(agent, false, observableErrors);
foreach (var check in observableErrors)
{
EditorGUILayout.HelpBox(check, MessageType.Warning);
}
}
var brainParameters = behaviorParameters.BrainParameters;
if (model != null)
{

{
var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
barracudaModel, brainParameters, sensorComponents, behaviorParameters.BehaviorType
barracudaModel, brainParameters, sensorComponents,
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType
);
foreach (var check in failedChecks)
{

18
com.unity.ml-agents/Runtime/Agent.cs


using UnityEngine;
using Unity.Barracuda;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;
using UnityEngine.Serialization;

m_Brain = m_PolicyFactory.GeneratePolicy(Heuristic);
ResetData();
Initialize();
InitializeSensors();
using (TimerStack.Instance.Scoped("InitializeSensors"))
{
InitializeSensors();
}
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.

/// </summary>
internal void InitializeSensors()
{
if (m_PolicyFactory.ObservableAttributeHandling != ObservableAttributeOptions.Ignore)
{
var excludeInherited =
m_PolicyFactory.ObservableAttributeHandling == ObservableAttributeOptions.ExcludeInherited;
using (TimerStack.Instance.Scoped("CreateObservableSensors"))
{
var observableSensors = ObservableAttribute.CreateObservableSensors(this, excludeInherited);
sensors.AddRange(observableSensors);
}
}
// Get all attached sensor components
SensorComponent[] attachedSensorComponents;
if (m_PolicyFactory.UseChildSensors)

43
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensorComponents">Attached sensor components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
SensorComponent[] sensorComponents, BehaviorType behaviorType = BehaviorType.Default)
SensorComponent[] sensorComponents, int observableAttributeTotalSize = 0,
BehaviorType behaviorType = BehaviorType.Default)
{
List<string> failedModelChecks = new List<string>();
if (model == null)

CheckOutputTensorPresence(model, memorySize))
;
failedModelChecks.AddRange(
CheckInputTensorShape(model, brainParameters, sensorComponents)
CheckInputTensorShape(model, brainParameters, sensorComponents, observableAttributeTotalSize)
);
failedModelChecks.AddRange(
CheckOutputTensorShape(model, brainParameters, isContinuous, actionSize)

/// Whether the model is expecting continuous or discrete control.
/// </param>
/// <param name="sensorComponents">Array of attached sensor components</param>
/// <param name="observableAttributeTotalSize">Total size of ObservableAttributes</param>
/// <returns>
/// A IEnumerable of string corresponding to the failed input presence checks.
/// </returns>

/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensorComponents">Attached sensors</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents)
Model model, BrainParameters brainParameters, SensorComponent[] sensorComponents,
int observableAttributeTotalSize)
new Dictionary<string, Func<BrainParameters, TensorProxy, SensorComponent[], string>>()
new Dictionary<string, Func<BrainParameters, TensorProxy, SensorComponent[], int, string>>()
{TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs) => null)},
{TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)},
{TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)},
tensorTester[mem.input] = ((bp, tensor, scs) => null);
tensorTester[mem.input] = ((bp, tensor, scs, i) => null);
}
var visObsIndex = 0;

continue;
}
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
(bp, tensor, scs) => CheckVisualObsShape(tensor, sensorComponent);
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent);
visObsIndex++;
}

else
{
var tester = tensorTester[tensor.name];
var error = tester.Invoke(brainParameters, tensor, sensorComponents);
var error = tester.Invoke(brainParameters, tensor, sensorComponents, observableAttributeTotalSize);
if (error != null)
{
failedModelChecks.Add(error);

/// </param>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensorComponents">Array of attached sensor components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents)
BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents,
int observableAttributeTotalSize)
{
var vecObsSizeBp = brainParameters.VectorObservationSize;
var numStackedVector = brainParameters.NumStackedVectorObservations;

totalVectorSensorSize += sensorComp.GetObservationShape()[0];
}
}
totalVectorSensorSize += observableAttributeTotalSize;
if (vecObsSizeBp * numStackedVector + totalVectorSensorSize != totalVecObsSizeT)
{

sensorSizes += "]";
return $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
$"but received {vecObsSizeBp} x {numStackedVector} vector observations and " +
$"but received: \n" +
$"Vector observations: {vecObsSizeBp} x {numStackedVector}\n" +
$"Total [Observable] attributes: {observableAttributeTotalSize}\n" +
$"SensorComponent sizes: {sensorSizes}.";
}
return null;

/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="tensorProxy"> The tensor that is expected by the model</param>
/// <param name="sensorComponents">Array of attached sensor components</param>
/// <param name="sensorComponents">Array of attached sensor components (unused).</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes (unused).</param>
BrainParameters brainParameters, TensorProxy tensorProxy, SensorComponent[] sensorComponents)
BrainParameters brainParameters, TensorProxy tensorProxy,
SensorComponent[] sensorComponents, int observableAttributeTotalSize)
{
var numberActionsBp = brainParameters.VectorActionSize.Length;
var numberActionsT = tensorProxy.shape[tensorProxy.shape.Length - 1];

43
com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs


using System;
using UnityEngine;
using UnityEngine.Serialization;
using Unity.MLAgents.Sensors.Reflection;
namespace Unity.MLAgents.Policies
{

/// neural network model.
/// </summary>
InferenceOnly
}
/// <summary>
/// Options for controlling how the Agent class is searched for <see cref="ObservableAttribute"/>s.
/// </summary>
public enum ObservableAttributeOptions
{
/// <summary>
/// All ObservableAttributes on the Agent will be ignored. If there are no
/// ObservableAttributes on the Agent, this will result in the fastest
/// initialization time.
/// </summary>
Ignore,
/// <summary>
/// Only members on the declared class will be examined; members that are
/// inherited are ignored. This is the default behavior, and a reasonable
/// tradeoff between performance and flexibility.
/// </summary>
/// <remarks>This corresponds to setting the
/// [BindingFlags.DeclaredOnly](https://docs.microsoft.com/en-us/dotnet/api/system.reflection.bindingflags?view=netcore-3.1)
/// when examining the fields and properties of the Agent class instance.
/// </remarks>
ExcludeInherited,
/// <summary>
/// All members on the class will be examined. This can lead to slower
/// startup times
/// </summary>
ExamineAll
}
/// <summary>

{
get { return m_UseChildSensors; }
set { m_UseChildSensors = value; }
}
[HideInInspector, SerializeField]
ObservableAttributeOptions m_ObservableAttributeHandling = ObservableAttributeOptions.Ignore;
/// <summary>
/// Determines how the Agent class is searched for <see cref="ObservableAttribute"/>s.
/// </summary>
public ObservableAttributeOptions ObservableAttributeHandling
{
get { return m_ObservableAttributeHandling; }
set { m_ObservableAttributeHandling = value; }
}
/// <summary>

6
com.unity.ml-agents/Runtime/Policies/BrainParameters.cs


public class BrainParameters
{
/// <summary>
/// The size of the observation space.
/// </summary>
/// <remarks>An agent creates the observation vector in its
/// The number of the observations that are added in
/// implementation.</remarks>
/// </summary>
/// <value>
/// The length of the vector containing observation values.
/// </value>

24
com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs


using UnityEngine;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors.Reflection;
namespace Unity.MLAgents.Tests
{

}
}
static List<TestAgent> GetFakeAgents()
static List<TestAgent> GetFakeAgents(ObservableAttributeOptions observableAttributeOptions = ObservableAttributeOptions.Ignore)
bpA.ObservableAttributeHandling = observableAttributeOptions;
var agentA = goA.AddComponent<TestAgent>();
var goB = new GameObject("goB");

bpB.ObservableAttributeHandling = observableAttributeOptions;
var agentB = goB.AddComponent<TestAgent>();
var agents = new List<TestAgent> { agentA, agentB };

{
var inputTensor = new TensorProxy
{
shape = new long[] { 2, 3 }
shape = new long[] { 2, 4 }
var agentInfos = GetFakeAgents();
var agentInfos = GetFakeAgents(ObservableAttributeOptions.ExamineAll);
generator.AddSensorIndex(0);
generator.AddSensorIndex(1);
generator.AddSensorIndex(2);
generator.AddSensorIndex(0); // ObservableAttribute (size 1)
generator.AddSensorIndex(1); // TestSensor (size 0)
generator.AddSensorIndex(2); // TestSensor (size 0)
generator.AddSensorIndex(3); // VectorSensor (size 3)
var agent0 = agentInfos[0];
var agent1 = agentInfos[1];
var inputs = new List<AgentInfoSensorsPair>

};
generator.Generate(inputTensor, batchSize, inputs);
Assert.IsNotNull(inputTensor.data);
Assert.AreEqual(inputTensor.data[0, 0], 1);
Assert.AreEqual(inputTensor.data[0, 2], 3);
Assert.AreEqual(inputTensor.data[1, 0], 4);
Assert.AreEqual(inputTensor.data[1, 2], 6);
Assert.AreEqual(inputTensor.data[0, 1], 1);
Assert.AreEqual(inputTensor.data[0, 3], 3);
Assert.AreEqual(inputTensor.data[1, 1], 4);
Assert.AreEqual(inputTensor.data[1, 3], 6);
alloc.Dispose();
}

65
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


using System.Reflection;
using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using Unity.MLAgents.Policies;
using Unity.MLAgents.SideChannels;

public int heuristicCalls;
public TestSensor sensor1;
public TestSensor sensor2;
[Observable("observableFloat")]
public float observableFloat;
public override void Initialize()
{

var agentGo1 = new GameObject("TestAgent");
agentGo1.AddComponent<TestAgent>();
var agent1 = agentGo1.GetComponent<TestAgent>();
var bp1 = agentGo1.GetComponent<BehaviorParameters>();
bp1.ObservableAttributeHandling = ObservableAttributeOptions.ExcludeInherited;
var agentGo2 = new GameObject("TestAgent");
agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();

Assert.AreEqual(0, agent2.agentActionCalls);
// Make sure the Sensors were sorted
Assert.AreEqual(agent1.sensors[0].GetName(), "testsensor1");
Assert.AreEqual(agent1.sensors[1].GetName(), "testsensor2");
Assert.AreEqual(agent1.sensors[0].GetName(), "observableFloat");
Assert.AreEqual(agent1.sensors[1].GetName(), "testsensor1");
Assert.AreEqual(agent1.sensors[2].GetName(), "testsensor2");
// agent2 should only have two sensors (no observableFloat)
Assert.AreEqual(agent2.sensors[0].GetName(), "testsensor1");
Assert.AreEqual(agent2.sensors[1].GetName(), "testsensor2");
}
}

public void TestAgentDontCallBaseOnEnable()
{
_InnerAgentTestOnEnableOverride();
}
}
[TestFixture]
public class ObservableAttributeBehaviorTests
{
public class BaseObservableAgent : Agent
{
[Observable]
public float BaseField;
}
public class DerivedObservableAgent : BaseObservableAgent
{
[Observable]
public float DerivedField;
}
[Test]
public void TestObservableAttributeBehaviorIgnore()
{
var variants = new[]
{
// No observables found
(ObservableAttributeOptions.Ignore, 0),
// Only DerivedField found
(ObservableAttributeOptions.ExcludeInherited, 1),
// DerivedField and BaseField found
(ObservableAttributeOptions.ExamineAll, 2)
};
foreach (var(behavior, expectedNumSensors) in variants)
{
var go = new GameObject();
var agent = go.AddComponent<DerivedObservableAgent>();
var bp = go.GetComponent<BehaviorParameters>();
bp.ObservableAttributeHandling = behavior;
agent.LazyInitialize();
int numAttributeSensors = 0;
foreach (var sensor in agent.sensors)
{
if (sensor.GetType() != typeof(VectorSensor))
{
numAttributeSensors++;
}
}
Assert.AreEqual(expectedNumSensors, numAttributeSensors);
}
}
}
}

10
com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs


#if UNITY_INCLUDE_TESTS
#if UNITY_INCLUDE_TESTS
using Unity.MLAgents.Sensors.Reflection;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.TestTools;

[Observable]
public float ObservableFloat;
public override void Heuristic(float[] actionsOut)
{
numHeuristicCalls++;

public override int[] GetObservationShape()
{
int[] shape = (int[]) wrappedComponent.GetObservationShape().Clone();
int[] shape = (int[])wrappedComponent.GetObservationShape().Clone();
for (var i = 0; i < shape.Length; i++)
{
shape[i] *= numStacks;

behaviorParams.BehaviorName = "TestBehavior";
behaviorParams.TeamId = 42;
behaviorParams.UseChildSensors = true;
behaviorParams.ObservableAttributeHandling = ObservableAttributeOptions.ExamineAll;
// Can't actually create an Agent with InferenceOnly and no model, so change back

8
com.unity.ml-agents/Runtime/Sensors/Reflection.meta


fileFormatVersion: 2
guid: 08ece3d7e9bb94089a9d59c6f269ab0a
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

11
com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta


fileFormatVersion: 2
guid: e5e4df2934c014aa3b835b9eb9ad20b3
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

292
com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs


using System;
using System.Collections.Generic;
using NUnit.Framework;
using UnityEngine;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class ObservableAttributeTests
{
class TestClass
{
// Non-observables
int m_NonObservableInt;
float m_NonObservableFloat;
//
// Int
//
[Observable]
public int m_IntMember;
int m_IntProperty;
[Observable]
public int IntProperty
{
get => m_IntProperty;
set => m_IntProperty = value;
}
//
// Float
//
[Observable("floatMember")]
public float m_FloatMember;
float m_FloatProperty;
[Observable("floatProperty")]
public float FloatProperty
{
get => m_FloatProperty;
set => m_FloatProperty = value;
}
//
// Bool
//
[Observable("boolMember")]
public bool m_BoolMember;
bool m_BoolProperty;
[Observable("boolProperty")]
public bool BoolProperty
{
get => m_BoolProperty;
set => m_BoolProperty = value;
}
//
// Vector2
//
[Observable("vector2Member")]
public Vector2 m_Vector2Member;
Vector2 m_Vector2Property;
[Observable("vector2Property")]
public Vector2 Vector2Property
{
get => m_Vector2Property;
set => m_Vector2Property = value;
}
//
// Vector3
//
[Observable("vector3Member")]
public Vector3 m_Vector3Member;
Vector3 m_Vector3Property;
[Observable("vector3Property")]
public Vector3 Vector3Property
{
get => m_Vector3Property;
set => m_Vector3Property = value;
}
//
// Vector4
//
[Observable("vector4Member")]
public Vector4 m_Vector4Member;
Vector4 m_Vector4Property;
[Observable("vector4Property")]
public Vector4 Vector4Property
{
get => m_Vector4Property;
set => m_Vector4Property = value;
}
//
// Quaternion
//
[Observable("quaternionMember")]
public Quaternion m_QuaternionMember;
Quaternion m_QuaternionProperty;
[Observable("quaternionProperty")]
public Quaternion QuaternionProperty
{
get => m_QuaternionProperty;
set => m_QuaternionProperty = value;
}
}
[Test]
public void TestGetObservableSensors()
{
var testClass = new TestClass();
testClass.m_IntMember = 1;
testClass.IntProperty = 2;
testClass.m_FloatMember = 1.1f;
testClass.FloatProperty = 1.2f;
testClass.m_BoolMember = true;
testClass.BoolProperty = true;
testClass.m_Vector2Member = new Vector2(2.0f, 2.1f);
testClass.Vector2Property = new Vector2(2.2f, 2.3f);
testClass.m_Vector3Member = new Vector3(3.0f, 3.1f, 3.2f);
testClass.Vector3Property = new Vector3(3.3f, 3.4f, 3.5f);
testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f);
testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f);
testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f);
testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f);
testClass.m_QuaternionMember = new Quaternion(5.0f, 5.1f, 5.2f, 5.3f);
testClass.QuaternionProperty = new Quaternion(5.4f, 5.5f, 5.5f, 5.7f);
var sensors = ObservableAttribute.CreateObservableSensors(testClass, false);
var sensorsByName = new Dictionary<string, ISensor>();
foreach (var sensor in sensors)
{
sensorsByName[sensor.GetName()] = sensor;
}
SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] { 1.0f });
SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] { 2.0f });
SensorTestHelper.CompareObservation(sensorsByName["floatMember"], new[] { 1.1f });
SensorTestHelper.CompareObservation(sensorsByName["floatProperty"], new[] { 1.2f });
SensorTestHelper.CompareObservation(sensorsByName["boolMember"], new[] { 1.0f });
SensorTestHelper.CompareObservation(sensorsByName["boolProperty"], new[] { 1.0f });
SensorTestHelper.CompareObservation(sensorsByName["vector2Member"], new[] { 2.0f, 2.1f });
SensorTestHelper.CompareObservation(sensorsByName["vector2Property"], new[] { 2.2f, 2.3f });
SensorTestHelper.CompareObservation(sensorsByName["vector3Member"], new[] { 3.0f, 3.1f, 3.2f });
SensorTestHelper.CompareObservation(sensorsByName["vector3Property"], new[] { 3.3f, 3.4f, 3.5f });
SensorTestHelper.CompareObservation(sensorsByName["vector4Member"], new[] { 4.0f, 4.1f, 4.2f, 4.3f });
SensorTestHelper.CompareObservation(sensorsByName["vector4Property"], new[] { 4.4f, 4.5f, 4.5f, 4.7f });
SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] { 5.0f, 5.1f, 5.2f, 5.3f });
SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] { 5.4f, 5.5f, 5.5f, 5.7f });
}
[Test]
public void TestGetTotalObservationSize()
{
var testClass = new TestClass();
var errors = new List<string>();
var expectedObsSize = 2 * (1 + 1 + 1 + 2 + 3 + 4 + 4);
Assert.AreEqual(expectedObsSize, ObservableAttribute.GetTotalObservationSize(testClass, false, errors));
Assert.AreEqual(0, errors.Count);
}
class BadClass
{
[Observable]
double m_Double;
[Observable]
double DoubleProperty
{
get => m_Double;
set => m_Double = value;
}
float m_WriteOnlyProperty;
[Observable]
// No get property, so we shouldn't be able to make a sensor out of this.
public float WriteOnlyProperty
{
set => m_WriteOnlyProperty = value;
}
}
[Test]
public void TestInvalidObservables()
{
var bad = new BadClass();
bad.WriteOnlyProperty = 1.0f;
var errors = new List<string>();
Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, false, errors));
Assert.AreEqual(3, errors.Count);
// Should be able to safely generate sensors (and get nothing back)
var sensors = ObservableAttribute.CreateObservableSensors(bad, false);
Assert.AreEqual(0, sensors.Count);
}
class StackingClass
{
[Observable(numStackedObservations: 2)]
public float FloatVal;
}
[Test]
public void TestObservableAttributeStacking()
{
var c = new StackingClass();
c.FloatVal = 1.0f;
var sensors = ObservableAttribute.CreateObservableSensors(c, false);
var sensor = sensors[0];
Assert.AreEqual(typeof(StackingSensor), sensor.GetType());
SensorTestHelper.CompareObservation(sensor, new[] { 0.0f, 1.0f });
sensor.Update();
c.FloatVal = 3.0f;
SensorTestHelper.CompareObservation(sensor, new[] { 1.0f, 3.0f });
var errors = new List<string>();
Assert.AreEqual(2, ObservableAttribute.GetTotalObservationSize(c, false, errors));
Assert.AreEqual(0, errors.Count);
}
class BaseClass
{
[Observable("base")]
public float m_BaseField;
[Observable("private")]
float m_PrivateField;
}
class DerivedClass : BaseClass
{
[Observable("derived")]
float m_DerivedField;
}
[Test]
public void TestObservableAttributeExcludeInherited()
{
var d = new DerivedClass();
d.m_BaseField = 1.0f;
// excludeInherited=false will get fields in the derived class, plus public and protected inherited fields
var sensorAll = ObservableAttribute.CreateObservableSensors(d, false);
Assert.AreEqual(2, sensorAll.Count);
// Note - actual order doesn't matter here, we can change this to use a HashSet if neeed.
Assert.AreEqual("derived", sensorAll[0].GetName());
Assert.AreEqual("base", sensorAll[1].GetName());
// excludeInherited=true will only get fields in the derived class
var sensorsDerivedOnly = ObservableAttribute.CreateObservableSensors(d, true);
Assert.AreEqual(1, sensorsDerivedOnly.Count);
Assert.AreEqual("derived", sensorsDerivedOnly[0].GetName());
var b = new BaseClass();
var baseSensors = ObservableAttribute.CreateObservableSensors(b, false);
Assert.AreEqual(2, baseSensors.Count);
}
}
}

11
com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta


fileFormatVersion: 2
guid: 33d7912e6b3504412bd261b40e46df32
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

19
com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a boolean field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class BoolReflectionSensor : ReflectionSensorBase
{
internal BoolReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 1)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var boolVal = (System.Boolean)GetReflectedValue();
writer[0] = boolVal ? 1.0f : 0.0f;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta


fileFormatVersion: 2
guid: be795c90750a6420d93f569b69ddc1ba
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

19
com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a float field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class FloatReflectionSensor : ReflectionSensorBase
{
internal FloatReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 1)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var floatVal = (System.Single)GetReflectedValue();
writer[0] = floatVal;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta


fileFormatVersion: 2
guid: 51ed837d5b7cd44349287ac8066120fc
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

19
com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps an integer field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class IntReflectionSensor : ReflectionSensorBase
{
internal IntReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 1)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var intVal = (System.Int32)GetReflectedValue();
writer[0] = (float)intVal;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta


fileFormatVersion: 2
guid: 5cae4c843cc074d11a549aaa3904c898
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

295
com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs


using System;
using System.Collections.Generic;
using System.Reflection;
using UnityEngine;
namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Specify that a field or property should be used to generate observations for an Agent.
/// For each field or property that uses ObservableAttribute, a corresponding
/// <see cref="ISensor"/> will be created during Agent initialization, and this
/// sensor will read the values during training and inference.
/// </summary>
/// <remarks>
/// ObservableAttribute is intended to make initial setup of an Agent easier. Because it
/// uses reflection to read the values of fields and properties at runtime, this may
/// be much slower than reading the values directly. If the performance of
/// ObservableAttribute is an issue, you can get the same functionality by overriding
/// <see cref="Agent.CollectObservations(VectorSensor)"/> or creating a custom
/// <see cref="ISensor"/> implementation to read the values without reflection.
///
/// Note that you do not need to adjust the VectorObservationSize in
/// <see cref="Unity.MLAgents.Policies.BrainParameters"/> when adding ObservableAttribute
/// to fields or properties.
/// </remarks>
/// <example>
/// This sample class will produce two observations, one for the m_Health field, and one
/// for the HealthPercent property.
/// <code>
/// using Unity.MLAgents;
/// using Unity.MLAgents.Sensors.Reflection;
///
/// public class MyAgent : Agent
/// {
/// [Observable]
/// int m_Health;
///
/// [Observable]
/// float HealthPercent
/// {
/// get => return 100.0f * m_Health / float(m_MaxHealth);
/// }
/// }
/// </code>
/// </example>
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
public class ObservableAttribute : Attribute
{
string m_Name;
int m_NumStackedObservations;
/// <summary>
/// Default binding flags used for reflection of members and properties.
/// </summary>
const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
/// <summary>
/// Supported types and their observation sizes.
/// </summary>
static Dictionary<Type, int> s_TypeSizes = new Dictionary<Type, int>()
{
{typeof(int), 1},
{typeof(bool), 1},
{typeof(float), 1},
{typeof(Vector2), 2},
{typeof(Vector3), 3},
{typeof(Vector4), 4},
{typeof(Quaternion), 4},
};
/// <summary>
/// ObservableAttribute constructor.
/// </summary>
/// <param name="name">Optional override for the sensor name. Note that all sensors for an Agent
/// must have a unique name.</param>
/// <param name="numStackedObservations">Number of frames to concatenate observations from.</param>
public ObservableAttribute(string name = null, int numStackedObservations = 1)
{
m_Name = name;
m_NumStackedObservations = numStackedObservations;
}
/// <summary>
/// Returns a FieldInfo for all fields that have an ObservableAttribute
/// </summary>
/// <param name="o">Object being reflected</param>
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
/// <returns></returns>
static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool excludeInherited)
{
// TODO cache these (and properties) by type, so that we only have to reflect once.
var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0);
var fields = o.GetType().GetFields(bindingFlags);
foreach (var field in fields)
{
var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute));
if (attr != null)
{
yield return (field, attr);
}
}
}
/// <summary>
/// Returns a PropertyInfo for all fields that have an ObservableAttribute
/// </summary>
/// <param name="o">Object being reflected</param>
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
/// <returns></returns>
static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool excludeInherited)
{
var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0);
var properties = o.GetType().GetProperties(bindingFlags);
foreach (var prop in properties)
{
var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute));
if (attr != null)
{
yield return (prop, attr);
}
}
}
/// <summary>
/// Creates sensors for each field and property with ObservableAttribute.
/// </summary>
/// <param name="o">Object being reflected</param>
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
/// <returns></returns>
internal static List<ISensor> CreateObservableSensors(object o, bool excludeInherited)
{
var sensorsOut = new List<ISensor>();
foreach (var(field, attr) in GetObservableFields(o, excludeInherited))
{
var sensor = CreateReflectionSensor(o, field, null, attr);
if (sensor != null)
{
sensorsOut.Add(sensor);
}
}
foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited))
{
if (!prop.CanRead)
{
// Skip unreadable properties.
continue;
}
var sensor = CreateReflectionSensor(o, null, prop, attr);
if (sensor != null)
{
sensorsOut.Add(sensor);
}
}
return sensorsOut;
}
/// <summary>
/// Create the ISensor for either the field or property on the provided object.
/// If the data type is unsupported, or the property is write-only, returns null.
/// </summary>
/// <param name="o"></param>
/// <param name="fieldInfo"></param>
/// <param name="propertyInfo"></param>
/// <param name="observableAttribute"></param>
/// <returns></returns>
/// <exception cref="UnityAgentsException"></exception>
static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute)
{
string memberName;
string declaringTypeName;
Type memberType;
if (fieldInfo != null)
{
declaringTypeName = fieldInfo.DeclaringType.Name;
memberName = fieldInfo.Name;
memberType = fieldInfo.FieldType;
}
else
{
declaringTypeName = propertyInfo.DeclaringType.Name;
memberName = propertyInfo.Name;
memberType = propertyInfo.PropertyType;
}
string sensorName;
if (string.IsNullOrEmpty(observableAttribute.m_Name))
{
sensorName = $"ObservableAttribute:{declaringTypeName}.{memberName}";
}
else
{
sensorName = observableAttribute.m_Name;
}
var reflectionSensorInfo = new ReflectionSensorInfo
{
Object = o,
FieldInfo = fieldInfo,
PropertyInfo = propertyInfo,
ObservableAttribute = observableAttribute,
SensorName = sensorName
};
ISensor sensor = null;
if (memberType == typeof(Int32))
{
sensor = new IntReflectionSensor(reflectionSensorInfo);
}
else if (memberType == typeof(float))
{
sensor = new FloatReflectionSensor(reflectionSensorInfo);
}
else if (memberType == typeof(bool))
{
sensor = new BoolReflectionSensor(reflectionSensorInfo);
}
else if (memberType == typeof(Vector2))
{
sensor = new Vector2ReflectionSensor(reflectionSensorInfo);
}
else if (memberType == typeof(Vector3))
{
sensor = new Vector3ReflectionSensor(reflectionSensorInfo);
}
else if (memberType == typeof(Vector4))
{
sensor = new Vector4ReflectionSensor(reflectionSensorInfo);
}
else if (memberType == typeof(Quaternion))
{
sensor = new QuaternionReflectionSensor(reflectionSensorInfo);
}
else
{
// For unsupported types, return null and we'll filter them out later.
return null;
}
// Wrap the base sensor in a StackingSensor if we're using stacking.
if (observableAttribute.m_NumStackedObservations > 1)
{
return new StackingSensor(sensor, observableAttribute.m_NumStackedObservations);
}
return sensor;
}
/// <summary>
/// Gets the sum of the observation sizes of the Observable fields and properties on an object.
/// Also appends errors to the errorsOut array.
/// </summary>
/// <param name="o"></param>
/// <param name="excludeInherited"></param>
/// <param name="errorsOut"></param>
/// <returns></returns>
internal static int GetTotalObservationSize(object o, bool excludeInherited, List<string> errorsOut)
{
int sizeOut = 0;
foreach (var(field, attr) in GetObservableFields(o, excludeInherited))
{
if (s_TypeSizes.ContainsKey(field.FieldType))
{
sizeOut += s_TypeSizes[field.FieldType] * attr.m_NumStackedObservations;
}
else
{
errorsOut.Add($"Unsupported Observable type {field.FieldType.Name} on field {field.Name}");
}
}
foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited))
{
if (s_TypeSizes.ContainsKey(prop.PropertyType))
{
if (prop.CanRead)
{
sizeOut += s_TypeSizes[prop.PropertyType] * attr.m_NumStackedObservations;
}
else
{
errorsOut.Add($"Observable property {prop.Name} is write-only.");
}
}
else
{
errorsOut.Add($"Unsupported Observable type {prop.PropertyType.Name} on property {prop.Name}");
}
}
return sizeOut;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta


fileFormatVersion: 2
guid: a75086dc66a594baea6b8b2935f5dacf
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

22
com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a quaternion field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class QuaternionReflectionSensor : ReflectionSensorBase
{
internal QuaternionReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 4)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var quatVal = (UnityEngine.Quaternion)GetReflectedValue();
writer[0] = quatVal.x;
writer[1] = quatVal.y;
writer[2] = quatVal.z;
writer[3] = quatVal.w;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta


fileFormatVersion: 2
guid: d38241d74074d459bb4590f7f5d16c80
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

97
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


using System.Reflection;
namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Construction info for a ReflectionSensorBase.
/// </summary>
internal struct ReflectionSensorInfo
{
public object Object;
public FieldInfo FieldInfo;
public PropertyInfo PropertyInfo;
public ObservableAttribute ObservableAttribute;
public string SensorName;
}
/// <summary>
/// Abstract base class for reflection-based sensors.
/// </summary>
internal abstract class ReflectionSensorBase : ISensor
{
protected object m_Object;
// Exactly one of m_FieldInfo and m_PropertyInfo should be non-null.
protected FieldInfo m_FieldInfo;
protected PropertyInfo m_PropertyInfo;
// Not currently used, but might want later.
protected ObservableAttribute m_ObservableAttribute;
// Cached sensor names and shapes.
string m_SensorName;
int[] m_Shape;
public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size)
{
m_Object = reflectionSensorInfo.Object;
m_FieldInfo = reflectionSensorInfo.FieldInfo;
m_PropertyInfo = reflectionSensorInfo.PropertyInfo;
m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute;
m_SensorName = reflectionSensorInfo.SensorName;
m_Shape = new[] {size};
}
/// <inheritdoc/>
public int[] GetObservationShape()
{
return m_Shape;
}
/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
WriteReflectedField(writer);
return m_Shape[0];
}
internal abstract void WriteReflectedField(ObservationWriter writer);
/// <summary>
/// Get either the reflected field, or return the reflected property.
/// This should be used by implementations in their WriteReflectedField() method.
/// </summary>
/// <returns></returns>
protected object GetReflectedValue()
{
return m_FieldInfo != null ?
m_FieldInfo.GetValue(m_Object) :
m_PropertyInfo.GetMethod.Invoke(m_Object, null);
}
/// <inheritdoc/>
public byte[] GetCompressedObservation()
{
return null;
}
/// <inheritdoc/>
public void Update() {}
/// <inheritdoc/>
public void Reset() {}
/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;
}
/// <inheritdoc/>
public string GetName()
{
return m_SensorName;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta


fileFormatVersion: 2
guid: 6b68d855fb94a45fbbeb0dbe968a35f8
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

20
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a Vector2 field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class Vector2ReflectionSensor : ReflectionSensorBase
{
internal Vector2ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 2)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var vecVal = (UnityEngine.Vector2)GetReflectedValue();
writer[0] = vecVal.x;
writer[1] = vecVal.y;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta


fileFormatVersion: 2
guid: da06ff33f6f2d409cbf240cffa2ba0be
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

21
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a Vector3 field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class Vector3ReflectionSensor : ReflectionSensorBase
{
internal Vector3ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 3)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var vecVal = (UnityEngine.Vector3)GetReflectedValue();
writer[0] = vecVal.x;
writer[1] = vecVal.y;
writer[2] = vecVal.z;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta


fileFormatVersion: 2
guid: e756976ec2a0943cfbc0f97a6550a85b
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

22
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs


namespace Unity.MLAgents.Sensors.Reflection
{
/// <summary>
/// Sensor that wraps a Vector4 field or property of an object, and returns
/// that as an observation.
/// </summary>
internal class Vector4ReflectionSensor : ReflectionSensorBase
{
internal Vector4ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, 4)
{}
internal override void WriteReflectedField(ObservationWriter writer)
{
var vecVal = (UnityEngine.Vector4)GetReflectedValue();
writer[0] = vecVal.x;
writer[1] = vecVal.y;
writer[2] = vecVal.z;
writer[3] = vecVal.w;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta


fileFormatVersion: 2
guid: 01d93aaa1b42b47b8960d303d7c498d3
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存