Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过 35 个字符。
 
 
 
 
 

276 行
10 KiB

using System;
using System.Text.RegularExpressions;
using Google.Protobuf;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Analytics;
using Unity.MLAgents.CommunicatorObjects;
using UnityEngine;
using UnityEngine.TestTools;
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class GrpcExtensionsTests
{
/// <summary>
/// Installs a freshly constructed <see cref="UnityRLCapabilities"/> on the Academy before
/// every test, so each test starts from the same capability set regardless of what an
/// earlier test assigned.
/// </summary>
[SetUp]
public void SetUp()
{
    Academy.Instance.TrainerCapabilities = new UnityRLCapabilities();
}

/// <summary>
/// Puts a freshly constructed <see cref="UnityRLCapabilities"/> back after every test.
/// Tests in this fixture assign modified capability sets to
/// <c>Academy.Instance.TrainerCapabilities</c>; without this reset the value written by the
/// last test would stay on the Academy once the fixture has finished.
/// </summary>
[TearDown]
public void TearDown()
{
    Academy.Instance.TrainerCapabilities = new UnityRLCapabilities();
}
/// <summary>
/// Converts a default-constructed <see cref="BrainParameters"/> to its proto form twice:
/// once with the capabilities installed by SetUp and once with HybridActions turned off.
/// The test has no explicit assertions; it passes when neither conversion throws.
/// </summary>
[Test]
public void TestDefaultBrainParametersToProto()
{
    const string behaviorName = "foo";
    var defaultParameters = new BrainParameters();

    // Pass 1: capabilities as left by SetUp.
    defaultParameters.ToProto(behaviorName, false);

    // Pass 2: same conversion with HybridActions disabled on the trainer capabilities.
    var withoutHybridActions = new UnityRLCapabilities
    {
        BaseRLCapabilities = true,
        HybridActions = false
    };
    Academy.Instance.TrainerCapabilities = withoutHybridActions;
    defaultParameters.ToProto(behaviorName, false);
}
/// <summary>
/// Converts a default, a continuous and a discrete <see cref="ActionSpec"/> to a brain
/// parameters proto, each one first with freshly constructed capabilities and then with
/// HybridActions disabled. The test has no explicit assertions; it passes when no
/// conversion throws.
/// </summary>
[Test]
public void TestDefaultActionSpecToProto()
{
    // Puts a freshly constructed capability set back on the Academy.
    void RestoreDefaultCapabilities()
    {
        Academy.Instance.TrainerCapabilities = new UnityRLCapabilities();
    }

    // Converts once under the current capabilities, then again with HybridActions off.
    void ConvertUnderBothCapabilitySets(ActionSpec spec)
    {
        spec.ToBrainParametersProto("foo", false);
        Academy.Instance.TrainerCapabilities = new UnityRLCapabilities
        {
            BaseRLCapabilities = true,
            HybridActions = false
        };
        spec.ToBrainParametersProto("foo", false);
    }

    // Default instance; the capabilities at this point are the ones SetUp installed.
    ConvertUnderBothCapabilitySets(new ActionSpec());

    // Continuous
    RestoreDefaultCapabilities();
    ConvertUnderBothCapabilitySets(ActionSpec.MakeContinuous(3));

    // Discrete
    RestoreDefaultCapabilities();
    ConvertUnderBothCapabilitySets(ActionSpec.MakeDiscrete(1, 2, 3));
}
/// <summary>
/// Round-trips a default, a continuous and a discrete <see cref="ActionSpec"/> through
/// <c>ToBrainParametersProto</c> followed by <c>ToBrainParameters</c>, each one first with
/// freshly constructed capabilities and then with HybridActions disabled. The test has no
/// explicit assertions; it passes when no conversion throws.
/// </summary>
[Test]
public void ToBrainParameters()
{
    // Factories are invoked lazily so each spec is created after the capability reset,
    // in the same order as a straight-line version of this test would do it.
    var specFactories = new Func<ActionSpec>[]
    {
        () => new ActionSpec(),                // Default
        () => ActionSpec.MakeContinuous(3),    // Continuous
        () => ActionSpec.MakeDiscrete(1, 2, 3) // Discrete
    };

    for (var index = 0; index < specFactories.Length; index++)
    {
        // The first spec runs under the capabilities SetUp installed; later ones need
        // the HybridActions override from the previous iteration undone first.
        if (index > 0)
        {
            Academy.Instance.TrainerCapabilities = new UnityRLCapabilities();
        }

        var spec = specFactories[index]();

        var protoWithCurrentCapabilities = spec.ToBrainParametersProto("foo", false);
        protoWithCurrentCapabilities.ToBrainParameters();

        Academy.Instance.TrainerCapabilities = new UnityRLCapabilities
        {
            BaseRLCapabilities = true,
            HybridActions = false
        };

        var protoWithoutHybridActions = spec.ToBrainParametersProto("foo", false);
        protoWithoutHybridActions.ToBrainParameters();
    }
}
/// <summary>
/// Converts a default-constructed <see cref="AgentInfo"/> to its proto forms, then repeats
/// the agent-info conversion with a non-zero group id while MultiAgentGroups is disabled
/// and enabled. A warning log matching <c>.+</c> is expected during the test.
/// </summary>
[Test]
public void TestDefaultAgentInfoToProto()
{
    var info = new AgentInfo();

    // Info/action pair conversion of the default instance, plus one hand-built observation
    // so that GetObservationSummaries has something to process.
    var infoActionPair = info.ToInfoActionPairProto();
    var observations = infoActionPair.AgentInfo.Observations;
    // NOTE(review): initializer order is kept exactly as written; if CompressedData and
    // FloatData belong to the same proto oneof, the later assignment wins — confirm.
    observations.Add(new ObservationProto
    {
        CompressedData = ByteString.Empty,
        CompressionType = CompressionTypeProto.None,
        FloatData = new ObservationProto.Types.FloatData(),
        ObservationType = ObservationTypeProto.Default,
        Name = "Sensor"
    });
    observations[0].Shape.Add(0);
    infoActionPair.GetObservationSummaries();

    // Plain agent-info conversion of the default instance.
    info.ToAgentInfoProto();

    // Non-zero group id with MultiAgentGroups switched off.
    info.groupId = 1;
    Academy.Instance.TrainerCapabilities = new UnityRLCapabilities
    {
        BaseRLCapabilities = true,
        MultiAgentGroups = false
    };
    info.ToAgentInfoProto();
    // Presumably the warning comes from the conversion just above — TODO confirm.
    LogAssert.Expect(LogType.Warning, new Regex(".+"));

    // Same conversion with MultiAgentGroups switched on.
    Academy.Instance.TrainerCapabilities = new UnityRLCapabilities
    {
        BaseRLCapabilities = true,
        MultiAgentGroups = true
    };
    info.ToAgentInfoProto();
}
/// <summary>
/// Converts a default-constructed <see cref="DemonstrationMetaData"/> to its proto form.
/// The test has no explicit assertions; it passes when the conversion does not throw.
/// </summary>
[Test]
public void TestDefaultDemonstrationMetaDataToProto()
{
    new DemonstrationMetaData().ToProto();
}
/// <summary>
/// Minimal <see cref="ISensor"/> stub for the tests in this fixture. The observation spec
/// and compression type are plain public fields that a test assigns directly; every other
/// member returns a fixed value or does nothing.
/// </summary>
class DummySensor : ISensor
{
    // Returned as-is by GetObservationSpec.
    public ObservationSpec ObservationSpec;

    // Wrapped in a CompressionSpec by GetCompressionSpec.
    public SensorCompressionType CompressionType;

    /// <summary>Returns the spec stored in <see cref="ObservationSpec"/>.</summary>
    public ObservationSpec GetObservationSpec() => ObservationSpec;

    /// <summary>Writes nothing to <paramref name="writer"/> and returns 0.</summary>
    public int Write(ObservationWriter writer) => 0;

    /// <summary>Returns a fixed two-byte payload (13, 37).</summary>
    public byte[] GetCompressedObservation() => new byte[] { 13, 37 };

    /// <summary>No-op.</summary>
    public void Update()
    {
    }

    /// <summary>No-op.</summary>
    public void Reset()
    {
    }

    /// <summary>Builds a <see cref="CompressionSpec"/> from <see cref="CompressionType"/>.</summary>
    public CompressionSpec GetCompressionSpec() => new CompressionSpec(CompressionType);

    /// <summary>Returns the fixed name "Dummy".</summary>
    public string GetName() => "Dummy";
}
/// <summary>
/// For several shape / compression / capability combinations, builds an observation proto
/// from a <see cref="DummySensor"/> and checks whether the proto carries compressed bytes
/// or float data.
/// </summary>
[Test]
public void TestGetObservationProtoCapabilities()
{
    // Shape, compression type, concatenatedPngObservations, expect compressed
    var variants = new[]
    {
        // Vector observations
        (new[] {3}, SensorCompressionType.None, false, false),
        // Uncompressed floats
        (new[] {4, 4, 3}, SensorCompressionType.None, false, false),
        // Compressed floats, 3 channels
        (new[] {4, 4, 3}, SensorCompressionType.PNG, false, true),
        // Compressed floats, >3 channels
        (new[] {4, 4, 4}, SensorCompressionType.PNG, false, false), // Unsupported - results in uncompressed
        (new[] {4, 4, 4}, SensorCompressionType.PNG, true, true), // Supported compressed
    };

    foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants)
    {
        var inplaceShape = InplaceArray<int>.FromList(shape);
        var dummySensor = new DummySensor();
        var obsWriter = new ObservationWriter();

        // Only rank-1 (vector) and rank-3 (visual) shapes are handled by this test.
        if (shape.Length == 1)
        {
            dummySensor.ObservationSpec = ObservationSpec.Vector(shape[0]);
        }
        else if (shape.Length == 3)
        {
            dummySensor.ObservationSpec = ObservationSpec.Visual(shape[0], shape[1], shape[2]);
        }
        else
        {
            throw new ArgumentOutOfRangeException();
        }

        dummySensor.CompressionType = compressionType;
        obsWriter.SetTarget(new float[128], inplaceShape, 0);

        // Capabilities are global on the Academy, so they are set per variant right
        // before the proto is generated.
        var caps = new UnityRLCapabilities
        {
            ConcatenatedPngObservations = supportsMultiPngObs
        };
        Academy.Instance.TrainerCapabilities = caps;

        var obsProto = dummySensor.GetObservationProto(obsWriter);
        if (expectCompressed)
        {
            Assert.Greater(obsProto.CompressedData.Length, 0);
            // IsNull instead of AreEqual(actual, null): the original passed the actual
            // value in the "expected" slot, which yields a misleading failure message.
            Assert.IsNull(obsProto.FloatData);
        }
        else
        {
            Assert.Greater(obsProto.FloatData.Data.Count, 0);
            // Expected value first, actual second, as NUnit's AreEqual defines them.
            Assert.AreEqual(0, obsProto.CompressedData.Length);
        }
    }
}
/// <summary>
/// Converts training-analytics messages to their event counterparts and checks that the
/// Python version, the behavior name and the reward-signal / training-feature flags carry
/// over.
/// </summary>
[Test]
public void TestDefaultTrainingEvents()
{
    // Environment-initialized message -> event.
    var envInitialized = new TrainingEnvironmentInitialized
    {
        PythonVersion = "test",
    };
    var envInitializedEvent = envInitialized.ToTrainingEnvironmentInitializedEvent();
    Assert.AreEqual(envInitialized.PythonVersion, envInitializedEvent.TrainerPythonVersion);

    // Behavior-initialized message -> event, with two reward signals and two features on.
    var behaviorInitialized = new TrainingBehaviorInitialized
    {
        BehaviorName = "testBehavior",
        ExtrinsicRewardEnabled = true,
        CuriosityRewardEnabled = true,
        RecurrentEnabled = true,
        SelfPlayEnabled = true,
    };
    var behaviorInitializedEvent = behaviorInitialized.ToTrainingBehaviorInitializedEvent();

    var expectedRewardSignals = RewardSignals.Extrinsic | RewardSignals.Curiosity;
    var expectedTrainingFeatures = TrainingFeatures.Recurrent | TrainingFeatures.SelfPlay;

    Assert.AreEqual(behaviorInitialized.BehaviorName, behaviorInitializedEvent.BehaviorName);
    Assert.AreEqual(expectedRewardSignals, behaviorInitializedEvent.RewardSignalFlags);
    Assert.AreEqual(expectedTrainingFeatures, behaviorInitializedEvent.TrainingFeatureFlags);
}
}
}