Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过 35 个字符。
 
 
 
 
 

99 行
3.0 KiB

using NUnit.Framework;
using Unity.Barracuda;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Inference;
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Tests for <see cref="ObservationWriter"/>: writing into a plain float array and into
    /// Barracuda tensors wrapped by <see cref="TensorProxy"/>.
    /// </summary>
    public class ObservationWriterTests
    {
        /// <summary>
        /// Indexer writes and <c>AddList</c> land in a float array, and the offset passed to
        /// <c>SetTarget</c> shifts where index 0 maps in the buffer.
        /// </summary>
        [Test]
        public void TestWritesToIList()
        {
            ObservationWriter writer = new ObservationWriter();
            var buffer = new[] { 0f, 0f, 0f };
            var shape = new InplaceArray<int>(3);

            // Elementwise writes
            writer.SetTarget(buffer, shape, 0);
            writer[0] = 1f;
            writer[2] = 2f;
            Assert.AreEqual(new[] { 1f, 0f, 2f }, buffer);

            // Elementwise writes with offset: writer index 0 maps to buffer[1].
            writer.SetTarget(buffer, shape, 1);
            writer[0] = 3f;
            Assert.AreEqual(new[] { 1f, 3f, 2f }, buffer);

            // AddList
            writer.SetTarget(buffer, shape, 0);
            writer.AddList(new[] { 4f, 5f });
            Assert.AreEqual(new[] { 4f, 5f, 2f }, buffer);

            // AddList with offset
            writer.SetTarget(buffer, shape, 1);
            writer.AddList(new[] { 6f, 7f });
            Assert.AreEqual(new[] { 4f, 6f, 7f }, buffer);
        }

        /// <summary>
        /// Indexer writes and <c>AddList</c> land in a 2x3 tensor. The second <c>SetTarget</c>
        /// argument selects the row (first tensor index). The third offsets the column.
        /// </summary>
        [Test]
        public void TestWritesToTensor()
        {
            ObservationWriter writer = new ObservationWriter();

            // Barracuda Tensor is IDisposable; release each one so the test does not leak
            // tensor storage.
            using (var tensor = new Tensor(2, 3))
            {
                var t = new TensorProxy
                {
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = tensor
                };

                writer.SetTarget(t, 0, 0);
                Assert.AreEqual(0f, t.data[0, 0]);
                writer[0] = 1f;
                Assert.AreEqual(1f, t.data[0, 0]);

                writer.SetTarget(t, 1, 1);
                writer[0] = 2f;
                writer[1] = 3f;
                // [0, 0] shouldn't change
                Assert.AreEqual(1f, t.data[0, 0]);
                Assert.AreEqual(2f, t.data[1, 1]);
                Assert.AreEqual(3f, t.data[1, 2]);
            }

            // AddList, on a fresh tensor so untouched cells can be checked against zero.
            using (var tensor = new Tensor(2, 3))
            {
                var t = new TensorProxy
                {
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = tensor
                };

                writer.SetTarget(t, 1, 1);
                writer.AddList(new[] { -1f, -2f });

                Assert.AreEqual(0f, t.data[0, 0]);
                Assert.AreEqual(0f, t.data[0, 1]);
                Assert.AreEqual(0f, t.data[0, 2]);
                Assert.AreEqual(0f, t.data[1, 0]);
                Assert.AreEqual(-1f, t.data[1, 1]);
                Assert.AreEqual(-2f, t.data[1, 2]);
            }
        }

        /// <summary>
        /// Three-index writes land in a 4D tensor. The third <c>SetTarget</c> argument offsets
        /// the last tensor index.
        /// </summary>
        [Test]
        public void TestWritesToTensor3D()
        {
            ObservationWriter writer = new ObservationWriter();

            // Barracuda Tensor is IDisposable; release it when the test is done.
            using (var tensor = new Tensor(2, 2, 2, 3))
            {
                var t = new TensorProxy
                {
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = tensor
                };

                writer.SetTarget(t, 0, 0);
                writer[1, 0, 1] = 1f;
                Assert.AreEqual(1f, t.data[0, 1, 0, 1]);

                // With an offset of 1, writer[1, 0, 0] targets the same cell as above.
                writer.SetTarget(t, 0, 1);
                writer[1, 0, 0] = 2f;
                Assert.AreEqual(2f, t.data[0, 1, 0, 1]);
            }
        }
    }
}