您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
131 行
4.3 KiB
131 行
4.3 KiB
using NUnit.Framework;
|
|
using UnityEngine;
|
|
using Unity.MLAgents.Sensors;
|
|
|
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Unit tests for <see cref="VectorSensor"/>: sensor naming, writing and clearing
    /// observations, the typed <c>AddObservation</c> overloads, one-hot encoding,
    /// overflow/underflow of the fixed-size buffer, and the reported observation type.
    /// </summary>
    public class VectorSensorTests
    {
        /// <summary>
        /// A sensor built without a name gets a default name that encodes its size
        /// (e.g. "VectorSensor_size4"); an explicitly supplied name is returned unchanged.
        /// </summary>
        [Test]
        public void TestCtor()
        {
            ISensor sensor = new VectorSensor(4);
            Assert.AreEqual("VectorSensor_size4", sensor.GetName());

            sensor = new VectorSensor(3, "test_sensor");
            Assert.AreEqual("test_sensor", sensor.GetName());
        }

        /// <summary>
        /// Observations added one at a time are written in insertion order, are stable
        /// across repeated reads, and are reset to zeros by <c>Update()</c>.
        /// </summary>
        [Test]
        public void TestWrite()
        {
            var sensor = new VectorSensor(4);
            sensor.AddObservation(1f);
            sensor.AddObservation(2f);
            sensor.AddObservation(3f);
            sensor.AddObservation(4f);

            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });
            // Check that if we don't call Update(), the same observations are produced
            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });

            // Check that Update() clears the data
            sensor.Update();
            SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f });
        }

        /// <summary>A single float observation is written unchanged.</summary>
        [Test]
        public void TestAddObservationFloat()
        {
            var sensor = new VectorSensor(1);
            sensor.AddObservation(1.2f);
            SensorTestHelper.CompareObservation(sensor, new[] { 1.2f });
        }

        /// <summary>
        /// The observation type defaults to <see cref="ObservationType.Default"/> and
        /// otherwise reflects the <c>observationType</c> constructor argument.
        /// </summary>
        [Test]
        public void TestObservationType()
        {
            // NUnit's Assert.AreEqual takes (expected, actual); keeping that order makes
            // failure messages label the two values correctly.
            var sensor = new VectorSensor(1);
            var spec = sensor.GetObservationSpec();
            Assert.AreEqual((int)ObservationType.Default, (int)spec.ObservationType);

            sensor = new VectorSensor(1, observationType: ObservationType.Default);
            spec = sensor.GetObservationSpec();
            Assert.AreEqual((int)ObservationType.Default, (int)spec.ObservationType);

            sensor = new VectorSensor(1, observationType: ObservationType.Goal);
            spec = sensor.GetObservationSpec();
            Assert.AreEqual((int)ObservationType.Goal, (int)spec.ObservationType);
        }

        /// <summary>An int observation is written as the equivalent float value.</summary>
        [Test]
        public void TestAddObservationInt()
        {
            var sensor = new VectorSensor(1);
            sensor.AddObservation(42);
            SensorTestHelper.CompareObservation(sensor, new[] { 42f });
        }

        /// <summary>
        /// Vector3 and Vector2 observations expand into their components in x, y(, z) order.
        /// </summary>
        [Test]
        public void TestAddObservationVec()
        {
            var sensor = new VectorSensor(3);
            sensor.AddObservation(new Vector3(1, 2, 3));
            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f });

            sensor = new VectorSensor(2);
            sensor.AddObservation(new Vector2(4, 5));
            SensorTestHelper.CompareObservation(sensor, new[] { 4f, 5f });
        }

        /// <summary>
        /// A quaternion observation expands into four components; the identity
        /// rotation is expected as { 0, 0, 0, 1 }.
        /// </summary>
        [Test]
        public void TestAddObservationQuaternion()
        {
            var sensor = new VectorSensor(4);
            sensor.AddObservation(Quaternion.identity);
            SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 1f });
        }

        /// <summary>A float collection observation is written element by element, in order.</summary>
        [Test]
        public void TestWriteEnumerable()
        {
            var sensor = new VectorSensor(4);
            sensor.AddObservation(new[] { 1f, 2f, 3f, 4f });

            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });
        }

        /// <summary>A <c>true</c> bool observation is written as 1.</summary>
        [Test]
        public void TestAddObservationBool()
        {
            var sensor = new VectorSensor(1);
            sensor.AddObservation(true);
            SensorTestHelper.CompareObservation(sensor, new[] { 1f });
        }

        /// <summary>
        /// <c>AddOneHotObservation(2, 4)</c> writes four values with a single 1 at index 2.
        /// </summary>
        [Test]
        public void TestAddObservationOneHot()
        {
            var sensor = new VectorSensor(4);
            sensor.AddOneHotObservation(2, 4);
            SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 0f });
        }

        /// <summary>
        /// Adding more values than the sensor's size keeps only the first <c>size</c>
        /// values in the written observation.
        /// </summary>
        [Test]
        public void TestWriteTooMany()
        {
            var sensor = new VectorSensor(2);
            sensor.AddObservation(new[] { 1f, 2f, 3f, 4f });

            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f });
        }

        /// <summary>
        /// Adding fewer values than the sensor's size pads the remaining slots with zeros.
        /// </summary>
        [Test]
        public void TestWriteNotEnough()
        {
            var sensor = new VectorSensor(4);
            sensor.AddObservation(new[] { 1f, 2f });

            // Make sure extra zeros are added
            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 0f, 0f });
        }
    }
}
|