浏览代码

Enum reflection sensor (#4006)

* WIP enum reflection sensor

* handle unknown values

* changelog

* handle flags
/test-sampler
GitHub 5 年前
当前提交
1577786f
共有 6 个文件被更改,包括 166 次插入19 次删除
  1. 2
      com.unity.ml-agents/CHANGELOG.md
  2. 39
      com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs
  3. 6
      com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
  4. 71
      com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs
  5. 5
      com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
  6. 62
      com.unity.ml-agents/Runtime/Sensors/Reflection/EnumReflectionSensor.cs

2
com.unity.ml-agents/CHANGELOG.md


### Minor Changes
#### com.unity.ml-agents (C#)
- `ObservableAttribute` was added. Adding the attribute to fields or properties on an Agent will allow it to generate
observations via reflection.
observations via reflection. (#3925, #4006)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Curriculum and Parameter Randomization configurations have been merged
into the main training configuration file. Note that this means training

39
com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs


memberType = propertyInfo.PropertyType;
}
if (!s_TypeToSensorInfo.ContainsKey(memberType))
if (!s_TypeToSensorInfo.ContainsKey(memberType) && !memberType.IsEnum)
{
// For unsupported types, return null and we'll filter them out later.
return null;

SensorName = sensorName
};
var (_, sensorType) = s_TypeToSensorInfo[memberType];
var sensor = (ISensor) Activator.CreateInstance(sensorType, reflectionSensorInfo);
ISensor sensor = null;
if (memberType.IsEnum)
{
sensor = new EnumReflectionSensor(reflectionSensorInfo);
}
else
{
var (_, sensorType) = s_TypeToSensorInfo[memberType];
sensor = (ISensor) Activator.CreateInstance(sensorType, reflectionSensorInfo);
}
// Wrap the base sensor in a StackingSensor if we're using stacking.
if (observableAttribute.m_NumStackedObservations > 1)

var (obsSize, _) = s_TypeToSensorInfo[field.FieldType];
sizeOut += obsSize * attr.m_NumStackedObservations;
}
else if (field.FieldType.IsEnum)
{
sizeOut += EnumReflectionSensor.GetEnumObservationSize(field.FieldType);
}
else
{
errorsOut.Add($"Unsupported Observable type {field.FieldType.Name} on field {field.Name}");

foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited))
{
if (s_TypeToSensorInfo.ContainsKey(prop.PropertyType))
if (!prop.CanRead)
{
errorsOut.Add($"Observable property {prop.Name} is write-only.");
}
else if (s_TypeToSensorInfo.ContainsKey(prop.PropertyType))
if (prop.CanRead)
{
var (obsSize, _) = s_TypeToSensorInfo[prop.PropertyType];
sizeOut += obsSize * attr.m_NumStackedObservations;
}
else
{
errorsOut.Add($"Observable property {prop.Name} is write-only.");
}
var (obsSize, _) = s_TypeToSensorInfo[prop.PropertyType];
sizeOut += obsSize * attr.m_NumStackedObservations;
}
else if (prop.PropertyType.IsEnum)
{
sizeOut += EnumReflectionSensor.GetEnumObservationSize(prop.PropertyType);
}
else
{

6
com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs


using System;
using System.Reflection;
namespace Unity.MLAgents.Sensors.Reflection

public PropertyInfo PropertyInfo;
public ObservableAttribute ObservableAttribute;
public string SensorName;
public Type GetMemberType()
{
return FieldInfo != null ? FieldInfo.FieldType : PropertyInfo.PropertyType;
}
}
/// <summary>

71
com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs


[TestFixture]
public class ObservableAttributeTests
{
public enum TestEnum
{
ValueA = -100,
ValueB = 1,
ValueC = 42,
}
[Flags]
public enum TestFlags
{
FlagA = 1,
FlagB = 2,
FlagC = 4
}
class TestClass
{
// Non-observables

get => m_QuaternionProperty;
set => m_QuaternionProperty = value;
}
//
// Enum
//
[Observable("enumMember")]
public TestEnum m_EnumMember = TestEnum.ValueA;
TestEnum m_EnumProperty = TestEnum.ValueC;
[Observable("enumProperty")]
public TestEnum EnumProperty
{
get => m_EnumProperty;
set => m_EnumProperty = value;
}
[Observable("badEnumMember")]
public TestEnum m_BadEnumMember = (TestEnum)1337;
//
// Flags
//
[Observable("flagMember")]
public TestFlags m_FlagMember = TestFlags.FlagA;
TestFlags m_FlagProperty = TestFlags.FlagB | TestFlags.FlagC;
[Observable("flagProperty")]
public TestFlags FlagProperty
{
get => m_FlagProperty;
set => m_FlagProperty = value;
}
}
[Test]

SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] { 5.0f, 5.1f, 5.2f, 5.3f });
SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] { 5.4f, 5.5f, 5.5f, 5.7f });
// Actual ordering is B, C, A
SensorTestHelper.CompareObservation(sensorsByName["enumMember"], new[] { 0.0f, 0.0f, 1.0f });
SensorTestHelper.CompareObservation(sensorsByName["enumProperty"], new[] { 0.0f, 1.0f, 0.0f });
SensorTestHelper.CompareObservation(sensorsByName["badEnumMember"], new[] { 0.0f, 0.0f, 0.0f });
SensorTestHelper.CompareObservation(sensorsByName["flagMember"], new[] { 1.0f, 0.0f, 0.0f });
SensorTestHelper.CompareObservation(sensorsByName["flagProperty"], new[] { 0.0f, 1.0f, 1.0f });
}
[Test]

var errors = new List<string>();
var expectedObsSize = 2 * (1 + 1 + 1 + 2 + 3 + 4 + 4);
var expectedObsSize = 2 * ( // two fields each of these
1 // int
+ 1 // float
+ 1 // bool
+ 2 // vector2
+ 3 // vector3
+ 4 // vector4
+ 4 // quaternion
+ 3 // TestEnum - 3 values
+ 3 // TestFlags - 3 values
)
+ 3; // TestEnum with bad value
Assert.AreEqual(expectedObsSize, ObservableAttribute.GetTotalObservationSize(testClass, false, errors));
Assert.AreEqual(0, errors.Count);
}

5
com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs


Assert.AreEqual(fill, output[0]);
sensor.Write(writer);
for (var i = 0; i < numExpected; i++)
{
Assert.AreEqual(expected[i], output[i]);
}
Assert.AreEqual(expected, output);
}
}

62
com.unity.ml-agents/Runtime/Sensors/Reflection/EnumReflectionSensor.cs


using System;
using UnityEngine;
namespace Unity.MLAgents.Sensors.Reflection
{
internal class EnumReflectionSensor: ReflectionSensorBase
{
Array m_Values;
bool m_IsFlags;
internal EnumReflectionSensor(ReflectionSensorInfo reflectionSensorInfo)
: base(reflectionSensorInfo, GetEnumObservationSize(reflectionSensorInfo.GetMemberType()))
{
var memberType = reflectionSensorInfo.GetMemberType();
m_Values = Enum.GetValues(memberType);
m_IsFlags = memberType.IsDefined(typeof(FlagsAttribute), false);
}
internal override void WriteReflectedField(ObservationWriter writer)
{
// Write the enum value as a one-hot encoding.
// Note that unknown enum values will record all 0's.
// Flags will get treated as a sequence of bools.
var enumValue = (Enum)GetReflectedValue();
int i = 0;
foreach(var val in m_Values)
{
if (m_IsFlags)
{
if (enumValue.HasFlag((Enum)val))
{
writer[i] = 1.0f;
}
else
{
writer[i] = 0.0f;
}
}
else
{
if (val.Equals(enumValue))
{
writer[i] = 1.0f;
}
else
{
writer[i] = 0.0f;
}
}
i++;
}
}
internal static int GetEnumObservationSize(Type t)
{
var values = Enum.GetValues(t);
// Account for all enum values
return values.Length;
}
}
}
正在加载...
取消
保存