您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
99 行
3.0 KiB
99 行
3.0 KiB
using NUnit.Framework;
|
|
using Unity.Barracuda;
|
|
using Unity.MLAgents.Sensors;
|
|
using Unity.MLAgents.Inference;
|
|
|
|
|
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Tests for <see cref="ObservationWriter"/>: indexer writes and <c>AddList</c> into a plain
    /// float array, and into 2D and 4D Barracuda tensors wrapped in a <see cref="TensorProxy"/>.
    /// </summary>
    public class ObservationWriterTests
    {
        /// <summary>
        /// Checks indexer writes and <c>AddList</c> against a float array target, both with a
        /// starting offset of 0 and of 1.
        /// </summary>
        [Test]
        public void TestWritesToIList()
        {
            ObservationWriter writer = new ObservationWriter();
            var buffer = new[] { 0f, 0f, 0f };
            var shape = new InplaceArray<int>(3);

            writer.SetTarget(buffer, shape, 0);
            // Elementwise writes
            writer[0] = 1f;
            writer[2] = 2f;
            Assert.AreEqual(new[] { 1f, 0f, 2f }, buffer);

            // Elementwise writes with offset: writer index 0 lands on buffer[1].
            writer.SetTarget(buffer, shape, 1);
            writer[0] = 3f;
            Assert.AreEqual(new[] { 1f, 3f, 2f }, buffer);

            // AddList writes its values consecutively, starting at the offset.
            writer.SetTarget(buffer, shape, 0);
            writer.AddList(new[] { 4f, 5f });
            Assert.AreEqual(new[] { 4f, 5f, 2f }, buffer);

            // AddList with offset
            writer.SetTarget(buffer, shape, 1);
            writer.AddList(new[] { 6f, 7f });
            Assert.AreEqual(new[] { 4f, 6f, 7f }, buffer);
        }

        /// <summary>
        /// Checks indexer writes and <c>AddList</c> against a 2x3 tensor. The second
        /// <c>SetTarget</c> argument selects the first tensor index (presumably the batch row —
        /// confirm against ObservationWriter), and the third offsets the second index.
        /// </summary>
        [Test]
        public void TestWritesToTensor()
        {
            ObservationWriter writer = new ObservationWriter();

            // Barracuda tensors are IDisposable; dispose each one instead of leaking it.
            using (var data = new Tensor(2, 3))
            {
                var t = new TensorProxy
                {
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = data
                };

                writer.SetTarget(t, 0, 0);
                Assert.AreEqual(0f, t.data[0, 0]);
                writer[0] = 1f;
                Assert.AreEqual(1f, t.data[0, 0]);

                writer.SetTarget(t, 1, 1);
                writer[0] = 2f;
                writer[1] = 3f;
                // [0, 0] shouldn't change
                Assert.AreEqual(1f, t.data[0, 0]);
                Assert.AreEqual(2f, t.data[1, 1]);
                Assert.AreEqual(3f, t.data[1, 2]);
            }

            // AddList, on a fresh tensor so the writes above cannot affect these checks.
            using (var data = new Tensor(2, 3))
            {
                var t = new TensorProxy
                {
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = data
                };

                writer.SetTarget(t, 1, 1);
                writer.AddList(new[] { -1f, -2f });
                Assert.AreEqual(0f, t.data[0, 0]);
                Assert.AreEqual(0f, t.data[0, 1]);
                Assert.AreEqual(0f, t.data[0, 2]);
                Assert.AreEqual(0f, t.data[1, 0]);
                Assert.AreEqual(-1f, t.data[1, 1]);
                Assert.AreEqual(-2f, t.data[1, 2]);
            }
        }

        /// <summary>
        /// Checks 3-index writer writes against a 4D (2, 2, 2, 3) tensor. Writer index
        /// [a, b, c] maps to tensor index [row, a, b, c + offset], where row and offset are the
        /// second and third <c>SetTarget</c> arguments.
        /// </summary>
        [Test]
        public void TestWritesToTensor3D()
        {
            ObservationWriter writer = new ObservationWriter();

            // Barracuda tensors are IDisposable; dispose instead of leaking it.
            using (var data = new Tensor(2, 2, 2, 3))
            {
                var t = new TensorProxy
                {
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = data
                };

                writer.SetTarget(t, 0, 0);
                writer[1, 0, 1] = 1f;
                Assert.AreEqual(1f, t.data[0, 1, 0, 1]);

                // With an offset of 1, writer[1, 0, 0] must land on the last index 0 + 1.
                writer.SetTarget(t, 0, 1);
                writer[1, 0, 0] = 2f;
                Assert.AreEqual(2f, t.data[0, 1, 0, 1]);
                // The un-offset destination must be left untouched.
                Assert.AreEqual(0f, t.data[0, 1, 0, 0]);
            }
        }
    }
}
|