您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
106 行
2.8 KiB
106 行
2.8 KiB
using NUnit.Framework;
|
|
using Unity.MLAgents.Sensors;
|
|
|
|
namespace Unity.MLAgents.Tests
|
|
{
|
|
public class Float2DSensor : ISensor
{
    // Test-only ISensor that exposes a 2D float array, indexed [row, column],
    // as a visual observation. The spec is built from (Height, Width, 1) and
    // Write fills channel 0 only.

    /// <summary>Number of columns; the second dimension of <see cref="floatData"/>.</summary>
    public int Width { get; }

    /// <summary>Number of rows; the first dimension of <see cref="floatData"/>.</summary>
    public int Height { get; }

    private readonly string m_Name;
    private readonly ObservationSpec m_ObservationSpec;

    /// <summary>
    /// Backing data indexed as [row, column]. Public so tests can fill it directly;
    /// <see cref="Write"/> reads the first <see cref="Height"/> x <see cref="Width"/> entries.
    /// </summary>
    public float[,] floatData;

    /// <summary>Creates a sensor backed by a newly allocated (zero-filled) array.</summary>
    /// <param name="width">Number of columns.</param>
    /// <param name="height">Number of rows.</param>
    /// <param name="name">Name returned by <see cref="GetName"/>.</param>
    public Float2DSensor(int width, int height, string name)
    {
        Width = width;
        Height = height;
        m_Name = name;
        m_ObservationSpec = ObservationSpec.Visual(height, width, 1);
        floatData = new float[height, width];
    }

    /// <summary>
    /// Creates a sensor that reads from an existing array. The array is stored by
    /// reference (not copied), so later writes to it are visible to <see cref="Write"/>.
    /// </summary>
    /// <param name="floatData">Data indexed as [row, column]; its dimensions define Height and Width.</param>
    /// <param name="name">Name returned by <see cref="GetName"/>.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="floatData"/> is null.</exception>
    public Float2DSensor(float[,] floatData, string name)
    {
        // Fail at the boundary with a descriptive exception instead of a
        // NullReferenceException from GetLength below.
        this.floatData = floatData ?? throw new System.ArgumentNullException(nameof(floatData));
        Height = floatData.GetLength(0);
        Width = floatData.GetLength(1);
        m_Name = name;
        m_ObservationSpec = ObservationSpec.Visual(Height, Width, 1);
    }

    /// <summary>Returns the name supplied at construction.</summary>
    public string GetName()
    {
        return m_Name;
    }

    /// <summary>Returns the visual observation spec built from (Height, Width, 1).</summary>
    public ObservationSpec GetObservationSpec()
    {
        return m_ObservationSpec;
    }

    /// <summary>This sensor produces no compressed form; always returns null.</summary>
    public byte[] GetCompressedObservation()
    {
        return null;
    }

    /// <summary>
    /// Copies <see cref="floatData"/> into channel 0 of <paramref name="writer"/>,
    /// iterating rows in the outer loop and columns in the inner loop.
    /// </summary>
    /// <param name="writer">Destination writer, indexed as [row, column, channel].</param>
    /// <returns>The number of values written: Height * Width.</returns>
    public int Write(ObservationWriter writer)
    {
        using (TimerStack.Instance.Scoped("Float2DSensor.Write"))
        {
            for (var h = 0; h < Height; h++)
            {
                for (var w = 0; w < Width; w++)
                {
                    writer[h, w, 0] = floatData[h, w];
                }
            }

            return Height * Width;
        }
    }

    /// <summary>No-op: data is read directly from <see cref="floatData"/> at write time.</summary>
    public void Update() { }

    /// <summary>No-op: this sensor keeps no per-episode state.</summary>
    public void Reset() { }

    /// <summary>
    /// Returns <c>CompressionSpec.Default()</c>. Assumes the default means "no compression",
    /// which is consistent with <see cref="GetCompressedObservation"/> returning null
    /// (TODO confirm against the ML-Agents version in use).
    /// </summary>
    public CompressionSpec GetCompressionSpec()
    {
        return CompressionSpec.Default();
    }
}
|
|
|
|
public class FloatVisualSensorTests
{
    // Unit tests for Float2DSensor.

    /// <summary>
    /// Fills a 4-row x 3-column sensor with row-major indices (3 * h + w) and checks
    /// that Write places every value at the matching flat offset of the output.
    /// </summary>
    [Test]
    public void TestFloat2DSensorWrite()
    {
        const int width = 3;
        const int height = 4;
        var sensor = new Float2DSensor(width, height, "floatsensor");
        for (var h = 0; h < height; h++)
        {
            for (var w = 0; w < width; w++)
            {
                sensor.floatData[h, w] = width * h + w;
            }
        }

        var output = new float[width * height];
        var writer = new ObservationWriter();
        writer.SetTarget(output, sensor.GetObservationSpec(), 0);
        var numWritten = sensor.Write(writer);

        Assert.AreEqual(width * height, numWritten);
        // Check every entry, including the last row, not just a prefix of the output.
        for (var i = 0; i < output.Length; i++)
        {
            Assert.AreEqual(i, output[i]);
        }
    }

    /// <summary>
    /// A sensor built from an existing array takes its Height/Width from the array's
    /// dimensions and reads from that same array instance.
    /// </summary>
    [Test]
    public void TestFloat2DSensorExternalData()
    {
        var data = new float[4, 3];
        var sensor = new Float2DSensor(data, "floatsensor");

        // NUnit's AreEqual is (expected, actual); keep that order for correct failure messages.
        Assert.AreEqual(4, sensor.Height);
        Assert.AreEqual(3, sensor.Width);
        Assert.AreSame(data, sensor.floatData);
    }
}
|
|
}
|