using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Analytics;
using Unity.MLAgents.CommunicatorObjects;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Tests for the <see cref="GrpcExtensions"/> helpers that convert ML-Agents runtime
    /// types (brain parameters, action specs, agent infos, sensors, training events)
    /// to and from their protobuf representations.
    /// </summary>
    [TestFixture]
    public class GrpcExtensionsTests
    {
        [Test]
        public void TestDefaultBrainParametersToProto()
        {
            // Should be able to convert a default instance to proto.
            var brain = new BrainParameters();
            brain.ToProto("foo", false);
        }

        [Test]
        public void TestDefaultActionSpecToProto()
        {
            // Should be able to convert a default instance to proto.
            var actionSpec = new ActionSpec();
            actionSpec.ToBrainParametersProto("foo", false);

            // Continuous
            actionSpec = ActionSpec.MakeContinuous(3);
            actionSpec.ToBrainParametersProto("foo", false);

            // Discrete
            actionSpec = ActionSpec.MakeDiscrete(1, 2, 3);
            actionSpec.ToBrainParametersProto("foo", false);
        }

        [Test]
        public void TestDefaultAgentInfoToProto()
        {
            // Should be able to convert a default instance to proto.
            var agentInfo = new AgentInfo();
            agentInfo.ToInfoActionPairProto();
            agentInfo.ToAgentInfoProto();
        }

        [Test]
        public void TestDefaultDemonstrationMetaDataToProto()
        {
            // Should be able to convert a default instance to proto.
            var demoMetaData = new DemonstrationMetaData();
            demoMetaData.ToProto();
        }

        /// <summary>
        /// Minimal <see cref="ISensor"/> whose observation spec and compression type are set
        /// directly by the test. It writes no floats and always returns the same two-byte
        /// "compressed" payload.
        /// </summary>
        class DummySensor : ISensor
        {
            public ObservationSpec ObservationSpec;
            public SensorCompressionType CompressionType;

            internal DummySensor()
            {
            }

            public ObservationSpec GetObservationSpec()
            {
                return ObservationSpec;
            }

            public int Write(ObservationWriter writer)
            {
                return 0;
            }

            public byte[] GetCompressedObservation()
            {
                return new byte[] { 13, 37 };
            }

            public void Update() { }

            public void Reset() { }

            public SensorCompressionType GetCompressionType()
            {
                return CompressionType;
            }

            public string GetName()
            {
                return "Dummy";
            }
        }

        /// <summary>
        /// <see cref="DummySensor"/> that also reports a caller-supplied compressed channel mapping.
        /// </summary>
        class DummySparseChannelSensor : DummySensor, ISparseChannelSensor
        {
            public int[] Mapping;

            internal DummySparseChannelSensor()
            {
            }

            public int[] GetCompressedChannelMapping()
            {
                return Mapping;
            }
        }

        [Test]
        public void TestGetObservationProtoCapabilities()
        {
            // Each variant is: observation shape, sensor compression type,
            // whether the trainer supports concatenated PNG observations,
            // and whether the resulting proto is expected to hold compressed data.
            var variants = new[]
            {
                // Vector observations
                (new[] { 3 }, SensorCompressionType.None, false, false),
                // Uncompressed floats
                (new[] { 4, 4, 3 }, SensorCompressionType.None, false, false),
                // Compressed floats, 3 channels
                (new[] { 4, 4, 3 }, SensorCompressionType.PNG, false, true),
                // Compressed floats, >3 channels, trainer without concatenated-PNG support:
                // unsupported, so the observation is sent uncompressed
                (new[] { 4, 4, 4 }, SensorCompressionType.PNG, false, false),
                // Compressed floats, >3 channels, trainer with concatenated-PNG support:
                // supported, so the observation is sent compressed
                (new[] { 4, 4, 4 }, SensorCompressionType.PNG, true, true),
            };

            // TrainerCapabilities lives on the Academy singleton; restore it afterwards so this
            // test's capability settings do not leak into other tests.
            var previousCapabilities = Academy.Instance.TrainerCapabilities;
            try
            {
                foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants)
                {
                    var dummySensor = new DummySensor();
                    var obsWriter = new ObservationWriter();

                    dummySensor.ObservationSpec = ObservationSpec.FromShape(shape);
                    dummySensor.CompressionType = compressionType;
                    obsWriter.SetTarget(new float[128], shape, 0);

                    var caps = new UnityRLCapabilities
                    {
                        ConcatenatedPngObservations = supportsMultiPngObs
                    };
                    Academy.Instance.TrainerCapabilities = caps;

                    var obsProto = dummySensor.GetObservationProto(obsWriter);

                    // Identify the failing variant in assertion messages.
                    var variantDescription =
                        $"shape=[{string.Join(", ", shape)}], compression={compressionType}, concatenatedPng={supportsMultiPngObs}";
                    if (expectCompressed)
                    {
                        Assert.Greater(obsProto.CompressedData.Length, 0, variantDescription);
                        Assert.IsNull(obsProto.FloatData, variantDescription);
                    }
                    else
                    {
                        Assert.Greater(obsProto.FloatData.Data.Count, 0, variantDescription);
                        Assert.AreEqual(0, obsProto.CompressedData.Length, variantDescription);
                    }
                }
            }
            finally
            {
                Academy.Instance.TrainerCapabilities = previousCapabilities;
            }
        }

        [Test]
        public void TestIsTrivialMapping()
        {
            // A sensor that does not implement ISparseChannelSensor is treated as trivial.
            Assert.IsTrue(GrpcExtensions.IsTrivialMapping(new DummySensor()));

            var sparseChannelSensor = new DummySparseChannelSensor();

            // Mappings expected to be trivial: null, all-zero (3 channels), and identity.
            sparseChannelSensor.Mapping = null;
            Assert.IsTrue(GrpcExtensions.IsTrivialMapping(sparseChannelSensor));
            sparseChannelSensor.Mapping = new[] { 0, 0, 0 };
            Assert.IsTrue(GrpcExtensions.IsTrivialMapping(sparseChannelSensor));
            sparseChannelSensor.Mapping = new[] { 0, 1, 2, 3, 4 };
            Assert.IsTrue(GrpcExtensions.IsTrivialMapping(sparseChannelSensor));

            // Mappings expected to be non-trivial.
            sparseChannelSensor.Mapping = new[] { 1, 2, 3, 4, -1, -1 };
            Assert.IsFalse(GrpcExtensions.IsTrivialMapping(sparseChannelSensor));
            sparseChannelSensor.Mapping = new[] { 0, 0, 0, 1, 1, 1 };
            Assert.IsFalse(GrpcExtensions.IsTrivialMapping(sparseChannelSensor));
        }

        [Test]
        public void TestDefaultTrainingEvents()
        {
            var trainingEnvInit = new TrainingEnvironmentInitialized
            {
                PythonVersion = "test",
            };
            var trainingEnvInitEvent = trainingEnvInit.ToTrainingEnvironmentInitializedEvent();
            Assert.AreEqual(trainingEnvInit.PythonVersion, trainingEnvInitEvent.TrainerPythonVersion);

            var trainingBehavInit = new TrainingBehaviorInitialized
            {
                BehaviorName = "testBehavior",
                ExtrinsicRewardEnabled = true,
                CuriosityRewardEnabled = true,
                RecurrentEnabled = true,
                SelfPlayEnabled = true,
            };
            var trainingBehavInitEvent = trainingBehavInit.ToTrainingBehaviorInitializedEvent();
            Assert.AreEqual(trainingBehavInit.BehaviorName, trainingBehavInitEvent.BehaviorName);

            // The enabled booleans are expected to be folded into the event's flag enums.
            Assert.AreEqual(RewardSignals.Extrinsic | RewardSignals.Curiosity, trainingBehavInitEvent.RewardSignalFlags);
            Assert.AreEqual(TrainingFeatures.Recurrent | TrainingFeatures.SelfPlay, trainingBehavInitEvent.TrainingFeatureFlags);
        }
    }
}