Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可选择 25 个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过 35 个字符。
 
 
 
 
 

173 行
5.6 KiB

using NUnit.Framework;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
{
/// <summary>
/// Tests for the <c>GrpcExtensions</c> helpers that convert ML-Agents types to protos.
/// </summary>
[TestFixture]
public class GrpcExtensionsTests
{
    /// <summary>
    /// A default-constructed <see cref="BrainParameters"/> must convert to proto without throwing.
    /// </summary>
    [Test]
    public void TestDefaultBrainParametersToProto()
    {
        var brain = new BrainParameters();

        Assert.DoesNotThrow(() => brain.ToProto("foo", false));
    }

    /// <summary>
    /// Default, continuous and discrete <see cref="ActionSpec"/> instances must all convert
    /// to a brain-parameters proto without throwing.
    /// </summary>
    [Test]
    public void TestDefaultActionSpecToProto()
    {
        // Default instance
        var defaultSpec = new ActionSpec();
        Assert.DoesNotThrow(() => defaultSpec.ToBrainParametersProto("foo", false));

        // Continuous
        var continuousSpec = ActionSpec.MakeContinuous(3);
        Assert.DoesNotThrow(() => continuousSpec.ToBrainParametersProto("foo", false));

        // Discrete
        var discreteSpec = ActionSpec.MakeDiscrete(1, 2, 3);
        Assert.DoesNotThrow(() => discreteSpec.ToBrainParametersProto("foo", false));
    }

    /// <summary>
    /// A default-constructed <see cref="AgentInfo"/> must convert to both proto forms without throwing.
    /// </summary>
    [Test]
    public void TestDefaultAgentInfoToProto()
    {
        var agentInfo = new AgentInfo();

        Assert.DoesNotThrow(() => agentInfo.ToInfoActionPairProto());
        Assert.DoesNotThrow(() => agentInfo.ToAgentInfoProto());
    }

    /// <summary>
    /// A default-constructed <see cref="DemonstrationMetaData"/> must convert to proto without throwing.
    /// </summary>
    [Test]
    public void TestDefaultDemonstrationMetaDataToProto()
    {
        var demoMetaData = new DemonstrationMetaData();

        Assert.DoesNotThrow(() => demoMetaData.ToProto());
    }

    /// <summary>
    /// Minimal <see cref="ISensor"/> stub. Its shape and compression type are set through
    /// public fields; it writes nothing and returns a fixed two-byte compressed payload.
    /// </summary>
    class DummySensor : ISensor
    {
        /// <summary>Value returned by <see cref="GetObservationShape"/>.</summary>
        public int[] Shape;

        /// <summary>Value returned by <see cref="GetCompressionType"/>.</summary>
        public SensorCompressionType CompressionType;

        internal DummySensor()
        {
        }

        /// <inheritdoc/>
        public int[] GetObservationShape()
        {
            return Shape;
        }

        /// <summary>Writes nothing to <paramref name="writer"/> and reports zero elements written.</summary>
        public int Write(ObservationWriter writer)
        {
            return 0;
        }

        /// <summary>Returns a fixed, non-empty placeholder payload.</summary>
        public byte[] GetCompressedObservation()
        {
            return new byte[] { 13, 37 };
        }

        /// <summary>No-op.</summary>
        public void Update() { }

        /// <summary>No-op.</summary>
        public void Reset() { }

        /// <inheritdoc/>
        public SensorCompressionType GetCompressionType()
        {
            return CompressionType;
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return "Dummy";
        }
    }

    /// <summary>
    /// <see cref="DummySensor"/> that also implements <see cref="ISparseChannelSensor"/>,
    /// returning a channel mapping set through a public field.
    /// </summary>
    class DummySparseChannelSensor : DummySensor, ISparseChannelSensor
    {
        /// <summary>Value returned by <see cref="GetCompressedChannelMapping"/>; may be null.</summary>
        public int[] Mapping;

        internal DummySparseChannelSensor()
        {
        }

        /// <inheritdoc/>
        public int[] GetCompressedChannelMapping()
        {
            return Mapping;
        }
    }

    /// <summary>
    /// Checks, for several shape / compression type / trainer capability combinations,
    /// whether <c>GetObservationProto</c> produces compressed data or float data.
    /// </summary>
    [Test]
    public void TestGetObservationProtoCapabilities()
    {
        // Shape, compression type, concatenatedPngObservations, expect compressed
        var variants = new[]
        {
            // Vector observations
            (new[] {3}, SensorCompressionType.None, false, false),
            // Uncompressed floats
            (new[] {4, 4, 3}, SensorCompressionType.None, false, false),
            // Compressed floats, 3 channels
            (new[] {4, 4, 3}, SensorCompressionType.PNG, false, true),
            // Compressed floats, >3 channels
            (new[] {4, 4, 4}, SensorCompressionType.PNG, false, false), // Unsupported - results in uncompressed
            (new[] {4, 4, 4}, SensorCompressionType.PNG, true, true), // Supported compressed
        };

        // TrainerCapabilities lives on the shared Academy instance, so remember the current
        // value and put it back afterwards; otherwise this test leaks state into other tests.
        var academy = Academy.Instance;
        var previousCapabilities = academy.TrainerCapabilities;
        try
        {
            foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants)
            {
                var dummySensor = new DummySensor
                {
                    Shape = shape,
                    CompressionType = compressionType
                };
                var obsWriter = new ObservationWriter();
                obsWriter.SetTarget(new float[128], shape, 0);

                academy.TrainerCapabilities = new UnityRLCapabilities
                {
                    ConcatenatedPngObservations = supportsMultiPngObs
                };

                var obsProto = dummySensor.GetObservationProto(obsWriter);

                // Identify the failing variant in the assertion output.
                var variant = $"shape=[{string.Join(", ", shape)}], compression={compressionType}, " +
                    $"concatenatedPngObservations={supportsMultiPngObs}";

                if (expectCompressed)
                {
                    Assert.Greater(obsProto.CompressedData.Length, 0, variant);
                    Assert.IsNull(obsProto.FloatData, variant);
                }
                else
                {
                    Assert.Greater(obsProto.FloatData.Data.Count, 0, variant);
                    Assert.AreEqual(0, obsProto.CompressedData.Length, variant);
                }
            }
        }
        finally
        {
            academy.TrainerCapabilities = previousCapabilities;
        }
    }

    /// <summary>
    /// Checks which sensors and channel mappings <c>GrpcExtensions.IsTrivialMapping</c>
    /// reports as trivial.
    /// </summary>
    [Test]
    public void TestIsTrivialMapping()
    {
        // A sensor that is not an ISparseChannelSensor is reported as trivial.
        Assert.IsTrue(GrpcExtensions.IsTrivialMapping(new DummySensor()));

        var sparseChannelSensor = new DummySparseChannelSensor();

        // Null mapping
        sparseChannelSensor.Mapping = null;
        Assert.IsTrue(GrpcExtensions.IsTrivialMapping(sparseChannelSensor));

        // All zeros
        sparseChannelSensor.Mapping = new[] { 0, 0, 0 };
        Assert.IsTrue(GrpcExtensions.IsTrivialMapping(sparseChannelSensor));

        // Consecutive indices starting at zero
        sparseChannelSensor.Mapping = new[] { 0, 1, 2, 3, 4 };
        Assert.IsTrue(GrpcExtensions.IsTrivialMapping(sparseChannelSensor));

        // Offset indices followed by -1 entries
        sparseChannelSensor.Mapping = new[] { 1, 2, 3, 4, -1, -1 };
        Assert.IsFalse(GrpcExtensions.IsTrivialMapping(sparseChannelSensor));

        // Repeated, non-consecutive indices
        sparseChannelSensor.Mapping = new[] { 0, 0, 0, 1, 1, 1 };
        Assert.IsFalse(GrpcExtensions.IsTrivialMapping(sparseChannelSensor));
    }
}
}