浏览代码
VectorSensor and StackedSensor (#2813)
VectorSensor and StackedSensor (#2813)
* WIP VectorSensor and StackedSensor * fix a few dumb mistakes * more VectorSensor * remove Update(), add util methods, hook into TensorGenerator * WriteApdater to write to tensors and arrays * write float observations * used circular buffer for stacked obs * cleanup * fix unit tests * docstrings * undo accidental checkins * rider suggestions, add range check * bounds check before writing * undo ProjectVersion.txt change * fix unit tests * unit test for VectorSensor * StackingSensor tests * missing meta file * missing meta file * WriteAdapter tests/develop-newnormalization
GitHub
5 年前
当前提交
1934bb75
共有 29 个文件被更改,包括 889 次插入 和 168 次删除
-
2UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
-
47UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs
-
9UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
-
1UnitySDK/Assets/ML-Agents/Editor/Tests/RayPerceptionTests.cs
-
6UnitySDK/Assets/ML-Agents/Editor/Tests/StandaloneBuildTest.cs
-
112UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
-
2UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
-
40UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs
-
2UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelRunner.cs
-
38UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs
-
6UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensor.cs
-
10UnitySDK/Assets/ML-Agents/Scripts/Sensor/ISensor.cs
-
8UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensor.cs
-
17UnitySDK/Assets/ML-Agents/Scripts/Sensor/SensorBase.cs
-
61UnitySDK/Assets/ML-Agents/Scripts/Utilities.cs
-
8UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor.meta
-
109UnitySDK/Assets/ML-Agents/Scripts/Sensor/StackingSensor.cs
-
3UnitySDK/Assets/ML-Agents/Scripts/Sensor/StackingSensor.cs.meta
-
168UnitySDK/Assets/ML-Agents/Scripts/Sensor/VectorSensor.cs
-
3UnitySDK/Assets/ML-Agents/Scripts/Sensor/VectorSensor.cs.meta
-
105UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs
-
3UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs.meta
-
42UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/StackingSensorTests.cs
-
3UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/StackingSensorTests.cs.meta
-
138UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs
-
11UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs.meta
-
100UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs
-
3UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs.meta
|
|||
fileFormatVersion: 2 |
|||
guid: 1b196836e6e3a4361bc62265ec88ebed |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
namespace MLAgents.Sensor |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps around another Sensor to provide temporal stacking.
|
|||
/// Conceptually, consecutive observations are stored left-to-right, which is how they're output
|
|||
/// For example, 4 stacked sets of observations would be output like
|
|||
/// | t = now - 3 | t = now -3 | t = now - 2 | t = now |
|
|||
/// Internally, a circular buffer of arrays is used. The m_CurrentIndex represents the most recent observation.
|
|||
/// </summary>
|
|||
public class StackingSensor : ISensor |
|||
{ |
|||
/// <summary>
|
|||
/// The wrapped sensor.
|
|||
/// </summary>
|
|||
ISensor m_WrappedSensor; |
|||
|
|||
/// <summary>
|
|||
/// Number of stacks to save
|
|||
/// </summary>
|
|||
int m_NumStackedObservations; |
|||
int m_UnstackedObservationSize; |
|||
|
|||
string m_Name; |
|||
int[] m_Shape; |
|||
|
|||
/// <summary>
|
|||
/// Buffer of previous observations
|
|||
/// </summary>
|
|||
float[][] m_StackedObservations; |
|||
|
|||
int m_CurrentIndex; |
|||
WriteAdapter m_LocalAdapter = new WriteAdapter(); |
|||
|
|||
/// <summary>
|
|||
///
|
|||
/// </summary>
|
|||
/// <param name="wrapped">The wrapped sensor</param>
|
|||
/// <param name="numStackedObservations">Number of stacked observations to keep</param>
|
|||
public StackingSensor(ISensor wrapped, int numStackedObservations) |
|||
{ |
|||
// TODO ensure numStackedObservations > 1
|
|||
m_WrappedSensor = wrapped; |
|||
m_NumStackedObservations = numStackedObservations; |
|||
|
|||
m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}"; |
|||
|
|||
var shape = wrapped.GetFloatObservationShape(); |
|||
m_Shape = new int[shape.Length]; |
|||
|
|||
m_UnstackedObservationSize = 1; |
|||
for (int d = 0; d < shape.Length; d++) |
|||
{ |
|||
m_Shape[d] = shape[d]; |
|||
m_UnstackedObservationSize *= shape[d]; |
|||
} |
|||
|
|||
// TODO support arbitrary stacking dimension
|
|||
m_Shape[0] *= numStackedObservations; |
|||
m_StackedObservations = new float[numStackedObservations][]; |
|||
for (var i = 0; i < numStackedObservations; i++) |
|||
{ |
|||
m_StackedObservations[i] = new float[m_UnstackedObservationSize]; |
|||
} |
|||
} |
|||
|
|||
public int Write(WriteAdapter adapter) |
|||
{ |
|||
// First, call the wrapped sensor's write method. Make sure to use our own adapater, not the passed one.
|
|||
m_LocalAdapter.SetTarget(m_StackedObservations[m_CurrentIndex], 0); |
|||
m_WrappedSensor.Write(m_LocalAdapter); |
|||
|
|||
// Now write the saved observations (oldest first)
|
|||
var numWritten = 0; |
|||
for (var i = 0; i < m_NumStackedObservations; i++) |
|||
{ |
|||
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; |
|||
adapter.AddRange(m_StackedObservations[obsIndex], numWritten); |
|||
numWritten += m_UnstackedObservationSize; |
|||
} |
|||
|
|||
// Finally update the index of the "current" buffer.
|
|||
m_CurrentIndex = (m_CurrentIndex + 1) % m_NumStackedObservations; |
|||
return numWritten; |
|||
} |
|||
|
|||
public int[] GetFloatObservationShape() |
|||
{ |
|||
return m_Shape; |
|||
} |
|||
|
|||
public string GetName() |
|||
{ |
|||
return m_Name; |
|||
} |
|||
|
|||
public virtual byte[] GetCompressedObservation() |
|||
{ |
|||
return null; |
|||
} |
|||
|
|||
public virtual SensorCompressionType GetCompressionType() |
|||
{ |
|||
return SensorCompressionType.None; |
|||
} |
|||
|
|||
// TODO support stacked compressed observations (byte stream)
|
|||
|
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 8b7a6e88d47d4438ad67e1862566462c |
|||
timeCreated: 1572299581 |
|
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
|
|||
namespace MLAgents.Sensor |
|||
{ |
|||
public class VectorSensor : ISensor |
|||
{ |
|||
// TODO use float[] instead
|
|||
// TOOD allow setting float[]
|
|||
List<float> m_Observations; |
|||
int[] m_Shape; |
|||
string m_Name; |
|||
|
|||
public VectorSensor(int observationSize, string name = null) |
|||
{ |
|||
if (name == null) |
|||
{ |
|||
name = $"VectorSensor_size{observationSize}"; |
|||
} |
|||
|
|||
m_Observations = new List<float>(observationSize); |
|||
m_Name = name; |
|||
m_Shape = new[] { observationSize }; |
|||
} |
|||
|
|||
public int Write(WriteAdapter adapter) |
|||
{ |
|||
var expectedObservations = m_Shape[0]; |
|||
if (m_Observations.Count > expectedObservations) |
|||
{ |
|||
// Too many observations, truncate
|
|||
Debug.LogWarningFormat( |
|||
"More observations ({0}) made than vector observation size ({1}). The observations will be truncated.", |
|||
m_Observations.Count, expectedObservations |
|||
); |
|||
m_Observations.RemoveRange(expectedObservations, m_Observations.Count - expectedObservations); |
|||
} |
|||
else if (m_Observations.Count < expectedObservations) |
|||
{ |
|||
// Not enough observations; pad with zeros.
|
|||
Debug.LogWarningFormat( |
|||
"Fewer observations ({0}) made than vector observation size ({1}). The observations will be padded.", |
|||
m_Observations.Count, expectedObservations |
|||
); |
|||
for (int i = m_Observations.Count; i < expectedObservations; i++) |
|||
{ |
|||
m_Observations.Add(0); |
|||
} |
|||
} |
|||
adapter.AddRange(m_Observations); |
|||
Clear(); |
|||
return expectedObservations; |
|||
} |
|||
|
|||
public int[] GetFloatObservationShape() |
|||
{ |
|||
return m_Shape; |
|||
} |
|||
|
|||
public string GetName() |
|||
{ |
|||
return m_Name; |
|||
} |
|||
|
|||
public virtual byte[] GetCompressedObservation() |
|||
{ |
|||
return null; |
|||
} |
|||
|
|||
public virtual SensorCompressionType GetCompressionType() |
|||
{ |
|||
return SensorCompressionType.None; |
|||
} |
|||
|
|||
void Clear() |
|||
{ |
|||
m_Observations.Clear(); |
|||
} |
|||
|
|||
void AddFloatObs(float obs) |
|||
{ |
|||
m_Observations.Add(obs); |
|||
} |
|||
|
|||
// Compatibility methods with Agent observation. These should be removed eventually.
|
|||
|
|||
/// <summary>
|
|||
/// Adds a float observation to the vector observations of the agent.
|
|||
/// </summary>
|
|||
/// <param name="observation">Observation.</param>
|
|||
public void AddObservation(float observation) |
|||
{ |
|||
AddFloatObs(observation); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Adds an integer observation to the vector observations of the agent.
|
|||
/// </summary>
|
|||
/// <param name="observation">Observation.</param>
|
|||
public void AddObservation(int observation) |
|||
{ |
|||
AddFloatObs(observation); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Adds an Vector3 observation to the vector observations of the agent.
|
|||
/// </summary>
|
|||
/// <param name="observation">Observation.</param>
|
|||
public void AddObservation(Vector3 observation) |
|||
{ |
|||
AddFloatObs(observation.x); |
|||
AddFloatObs(observation.y); |
|||
AddFloatObs(observation.z); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Adds an Vector2 observation to the vector observations of the agent.
|
|||
/// </summary>
|
|||
/// <param name="observation">Observation.</param>
|
|||
public void AddObservation(Vector2 observation) |
|||
{ |
|||
AddFloatObs(observation.x); |
|||
AddFloatObs(observation.y); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Adds a collection of float observations to the vector observations of the agent.
|
|||
/// </summary>
|
|||
/// <param name="observation">Observation.</param>
|
|||
public void AddObservation(IEnumerable<float> observation) |
|||
{ |
|||
foreach (var f in observation) |
|||
{ |
|||
AddFloatObs(f); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Adds a quaternion observation to the vector observations of the agent.
|
|||
/// </summary>
|
|||
/// <param name="observation">Observation.</param>
|
|||
public void AddObservation(Quaternion observation) |
|||
{ |
|||
AddFloatObs(observation.x); |
|||
AddFloatObs(observation.y); |
|||
AddFloatObs(observation.z); |
|||
AddFloatObs(observation.w); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Adds a boolean observation to the vector observation of the agent.
|
|||
/// </summary>
|
|||
/// <param name="observation"></param>
|
|||
public void AddObservation(bool observation) |
|||
{ |
|||
AddFloatObs(observation ? 1f : 0f); |
|||
} |
|||
|
|||
|
|||
public void AddOneHotObservation(int observation, int range) |
|||
{ |
|||
for (var i = 0; i < range; i++) |
|||
{ |
|||
AddFloatObs(i == observation ? 1.0f : 0.0f); |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: e3966c9961b343108808d91a4d140a68 |
|||
timeCreated: 1572300800 |
|
|||
using System.Collections.Generic; |
|||
using MLAgents.InferenceBrain; |
|||
|
|||
namespace MLAgents.Sensor |
|||
{ |
|||
/// <summary>
|
|||
/// Allows sensors to write to both TensorProxy and float arrays/lists.
|
|||
/// </summary>
|
|||
public class WriteAdapter |
|||
{ |
|||
IList<float> m_Data; |
|||
int m_Offset; |
|||
|
|||
TensorProxy m_Proxy; |
|||
int m_Batch; |
|||
|
|||
/// <summary>
|
|||
/// Set the adapter to write to an IList at the given channelOffset.
|
|||
/// </summary>
|
|||
/// <param name="data"></param>
|
|||
/// <param name="offset"></param>
|
|||
public void SetTarget(IList<float> data, int offset) |
|||
{ |
|||
m_Data = data; |
|||
m_Offset = offset; |
|||
m_Proxy = null; |
|||
m_Batch = -1; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Set the adapter to write to a TensorProxy at the given batch and channel offset.
|
|||
/// </summary>
|
|||
/// <param name="tensorProxy"></param>
|
|||
/// <param name="batchIndex"></param>
|
|||
/// <param name="channelOffset"></param>
|
|||
public void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset) |
|||
{ |
|||
m_Proxy = tensorProxy; |
|||
m_Batch = batchIndex; |
|||
m_Offset = channelOffset; |
|||
m_Data = null; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// 1D write access at a specified index. Use AddRange if possible instead.
|
|||
/// </summary>
|
|||
/// <param name="index">Index to write to</param>
|
|||
public float this[int index] |
|||
{ |
|||
set |
|||
{ |
|||
if (m_Data != null) |
|||
{ |
|||
m_Data[index + m_Offset] = value; |
|||
} |
|||
else |
|||
{ |
|||
m_Proxy.data[m_Batch, index + m_Offset] = value; |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// 3D write access at the specified height, width, and channel. Only usable with a TensorProxy target.
|
|||
/// </summary>
|
|||
/// <param name="h"></param>
|
|||
/// <param name="w"></param>
|
|||
/// <param name="ch"></param>
|
|||
public float this[int h, int w, int ch] |
|||
{ |
|||
set |
|||
{ |
|||
// Only TensorProxy supports 3D access
|
|||
m_Proxy.data[m_Batch, h, w, ch + m_Offset] = value; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Write the range of floats
|
|||
/// </summary>
|
|||
/// <param name="data"></param>
|
|||
/// <param name="writeOffset">Optional write offset</param>
|
|||
public void AddRange(IEnumerable<float> data, int writeOffset = 0) |
|||
{ |
|||
if (m_Data != null) |
|||
{ |
|||
int index = 0; |
|||
foreach (var val in data) |
|||
{ |
|||
m_Data[index + m_Offset + writeOffset] = val; |
|||
index++; |
|||
} |
|||
} |
|||
else |
|||
{ |
|||
int index = 0; |
|||
foreach (var val in data) |
|||
{ |
|||
m_Proxy.data[m_Batch, index + m_Offset + writeOffset] = val; |
|||
index++; |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 86bad2e6dded4a62853752a1713981f2 |
|||
timeCreated: 1572540197 |
|
|||
using NUnit.Framework; |
|||
using UnityEngine; |
|||
using MLAgents.Sensor; |
|||
|
|||
namespace MLAgents.Tests |
|||
{ |
|||
public class StackingSensorTests |
|||
{ |
|||
[Test] |
|||
public void TestCtor() |
|||
{ |
|||
ISensor wrapped = new VectorSensor(4); |
|||
ISensor sensor = new StackingSensor(wrapped, 4); |
|||
Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName()); |
|||
Assert.AreEqual(sensor.GetFloatObservationShape(), new [] {16}); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestStacking() |
|||
{ |
|||
VectorSensor wrapped = new VectorSensor(2); |
|||
ISensor sensor = new StackingSensor(wrapped, 3); |
|||
|
|||
wrapped.AddObservation(new [] {1f, 2f}); |
|||
SensorTestHelper.CompareObservation(sensor, new [] {0f, 0f, 0f, 0f, 1f, 2f}); |
|||
|
|||
wrapped.AddObservation(new [] {3f, 4f}); |
|||
SensorTestHelper.CompareObservation(sensor, new [] {0f, 0f, 1f, 2f, 3f, 4f}); |
|||
|
|||
wrapped.AddObservation(new [] {5f, 6f}); |
|||
SensorTestHelper.CompareObservation(sensor, new [] {1f, 2f, 3f, 4f, 5f, 6f}); |
|||
|
|||
wrapped.AddObservation(new [] {7f, 8f}); |
|||
SensorTestHelper.CompareObservation(sensor, new [] {3f, 4f, 5f, 6f, 7f, 8f}); |
|||
|
|||
wrapped.AddObservation(new [] {9f, 10f}); |
|||
SensorTestHelper.CompareObservation(sensor, new [] {5f, 6f, 7f, 8f, 9f, 10f}); |
|||
} |
|||
|
|||
|
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 7b071fdf91474d18a05ea20175c6b3bd |
|||
timeCreated: 1572564843 |
|
|||
using NUnit.Framework; |
|||
using UnityEngine; |
|||
using MLAgents.Sensor; |
|||
|
|||
namespace MLAgents.Tests |
|||
{ |
|||
public class SensorTestHelper |
|||
{ |
|||
public static void CompareObservation(ISensor sensor, float[] expected) |
|||
{ |
|||
var numExpected = expected.Length; |
|||
const float fill = -1337f; |
|||
var output = new float[numExpected]; |
|||
for (var i = 0; i < numExpected; i++) |
|||
{ |
|||
output[i] = fill; |
|||
} |
|||
Assert.AreEqual(fill, output[0]); |
|||
|
|||
WriteAdapter writer = new WriteAdapter(); |
|||
writer.SetTarget(output, 0); |
|||
|
|||
// Make sure WriteAdapter didn't touch anything
|
|||
Assert.AreEqual(fill, output[0]); |
|||
|
|||
sensor.Write(writer); |
|||
for (var i = 0; i < numExpected; i++) |
|||
{ |
|||
Assert.AreEqual(expected[i], output[i]); |
|||
} |
|||
} |
|||
} |
|||
|
|||
public class VectorSensorTests |
|||
{ |
|||
[Test] |
|||
public void TestCtor() |
|||
{ |
|||
ISensor sensor = new VectorSensor(4); |
|||
Assert.AreEqual("VectorSensor_size4", sensor.GetName()); |
|||
|
|||
sensor = new VectorSensor(3, "test_sensor"); |
|||
Assert.AreEqual("test_sensor", sensor.GetName()); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestWrite() |
|||
{ |
|||
var sensor = new VectorSensor(4); |
|||
sensor.AddObservation(1f); |
|||
sensor.AddObservation(2f); |
|||
sensor.AddObservation(3f); |
|||
sensor.AddObservation(4f); |
|||
|
|||
SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f }); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestAddObservationFloat() |
|||
{ |
|||
var sensor = new VectorSensor(1); |
|||
sensor.AddObservation(1.2f); |
|||
SensorTestHelper.CompareObservation(sensor, new []{1.2f}); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestAddObservationInt() |
|||
{ |
|||
var sensor = new VectorSensor(1); |
|||
sensor.AddObservation(42); |
|||
SensorTestHelper.CompareObservation(sensor, new []{42f}); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestAddObservationVec() |
|||
{ |
|||
var sensor = new VectorSensor(3); |
|||
sensor.AddObservation(new Vector3(1,2,3)); |
|||
SensorTestHelper.CompareObservation(sensor, new []{1f, 2f, 3f}); |
|||
|
|||
sensor = new VectorSensor(2); |
|||
sensor.AddObservation(new Vector2(4,5)); |
|||
SensorTestHelper.CompareObservation(sensor, new[] { 4f, 5f }); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestAddObservationQuaternion() |
|||
{ |
|||
var sensor = new VectorSensor(4); |
|||
sensor.AddObservation(Quaternion.identity); |
|||
SensorTestHelper.CompareObservation(sensor, new []{0f, 0f, 0f, 1f}); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestWriteEnumerable() |
|||
{ |
|||
var sensor = new VectorSensor(4); |
|||
sensor.AddObservation(new [] {1f, 2f, 3f, 4f}); |
|||
|
|||
SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f }); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestAddObservationBool() |
|||
{ |
|||
var sensor = new VectorSensor(1); |
|||
sensor.AddObservation(true); |
|||
SensorTestHelper.CompareObservation(sensor, new []{1f}); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestAddObservationOneHot() |
|||
{ |
|||
var sensor = new VectorSensor(4); |
|||
sensor.AddOneHotObservation(2, 4); |
|||
SensorTestHelper.CompareObservation(sensor, new []{0f, 0f, 1f, 0f}); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestWriteTooMany() |
|||
{ |
|||
var sensor = new VectorSensor(2); |
|||
sensor.AddObservation(new [] {1f, 2f, 3f, 4f}); |
|||
|
|||
SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f}); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestWriteNotEnough() |
|||
{ |
|||
var sensor = new VectorSensor(4); |
|||
sensor.AddObservation(new [] {1f, 2f}); |
|||
|
|||
// Make sure extra zeros are added
|
|||
SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 0f, 0f}); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 18c0d390ce4c5464ab48b96db0392eb0 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using NUnit.Framework; |
|||
using UnityEngine; |
|||
using MLAgents.Sensor; |
|||
|
|||
using Barracuda; |
|||
using MLAgents.InferenceBrain; |
|||
using MLAgents.InferenceBrain.Utils; |
|||
|
|||
|
|||
namespace MLAgents.Tests |
|||
{ |
|||
public class WriteAdapterTests |
|||
{ |
|||
[Test] |
|||
public void TestWritesToIList() |
|||
{ |
|||
WriteAdapter writer = new WriteAdapter(); |
|||
var buffer = new[] { 0f, 0f, 0f }; |
|||
|
|||
writer.SetTarget(buffer, 0); |
|||
// Elementwise writes
|
|||
writer[0] = 1f; |
|||
writer[2] = 2f; |
|||
Assert.AreEqual(new[] { 1f, 0f, 2f }, buffer); |
|||
|
|||
// Elementwise writes with offset
|
|||
writer.SetTarget(buffer, 1); |
|||
writer[0] = 3f; |
|||
Assert.AreEqual(new[] { 1f, 3f, 2f }, buffer); |
|||
|
|||
// AddRange
|
|||
writer.SetTarget(buffer, 0); |
|||
writer.AddRange(new [] {4f, 5f}); |
|||
Assert.AreEqual(new[] { 4f, 5f, 2f }, buffer); |
|||
|
|||
// AddRange with offset
|
|||
writer.SetTarget(buffer, 1); |
|||
writer.AddRange(new [] {6f, 7f}); |
|||
Assert.AreEqual(new[] { 4f, 6f, 7f }, buffer); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestWritesToTensor() |
|||
{ |
|||
WriteAdapter writer = new WriteAdapter(); |
|||
var t = new TensorProxy |
|||
{ |
|||
valueType = TensorProxy.TensorType.FloatingPoint, |
|||
data = new Tensor(2, 3) |
|||
}; |
|||
writer.SetTarget(t, 0, 0); |
|||
Assert.AreEqual(0f, t.data[0, 0]); |
|||
writer[0] = 1f; |
|||
Assert.AreEqual(1f, t.data[0, 0]); |
|||
|
|||
writer.SetTarget(t, 1, 1); |
|||
writer[0] = 2f; |
|||
writer[1] = 3f; |
|||
// [0, 0] shouldn't change
|
|||
Assert.AreEqual(1f, t.data[0, 0]); |
|||
Assert.AreEqual(2f, t.data[1, 1]); |
|||
Assert.AreEqual(3f, t.data[1, 2]); |
|||
|
|||
// AddRange
|
|||
t = new TensorProxy |
|||
{ |
|||
valueType = TensorProxy.TensorType.FloatingPoint, |
|||
data = new Tensor(2, 3) |
|||
}; |
|||
|
|||
writer.SetTarget(t, 1, 1); |
|||
writer.AddRange(new [] {-1f, -2f}); |
|||
Assert.AreEqual(0f, t.data[0, 0]); |
|||
Assert.AreEqual(0f, t.data[0, 1]); |
|||
Assert.AreEqual(0f, t.data[0, 2]); |
|||
Assert.AreEqual(0f, t.data[1, 0]); |
|||
Assert.AreEqual(-1f, t.data[1, 1]); |
|||
Assert.AreEqual(-2f, t.data[1, 2]); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestWritesToTensor3D() |
|||
{ |
|||
WriteAdapter writer = new WriteAdapter(); |
|||
var t = new TensorProxy |
|||
{ |
|||
valueType = TensorProxy.TensorType.FloatingPoint, |
|||
data = new Tensor(2, 2, 2, 3) |
|||
}; |
|||
|
|||
writer.SetTarget(t, 0, 0); |
|||
writer[1, 0, 1] = 1f; |
|||
Assert.AreEqual(1f, t.data[0, 1, 0, 1]); |
|||
|
|||
writer.SetTarget(t, 0, 1); |
|||
writer[1, 0, 0] = 2f; |
|||
Assert.AreEqual(2f, t.data[0, 1, 0, 1]); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 3de9cbda816e4d7b907e765577dd54f7 |
|||
timeCreated: 1572568337 |
撰写
预览
正在加载...
取消
保存
Reference in new issue