using NUnit.Framework; using UnityEngine; using Unity.MLAgents.Policies; using Unity.MLAgents.Demonstrations; using Unity.MLAgents.Actuators; using Unity.MLAgents.Sensors; namespace Unity.MLAgents.Tests { [TestFixture] public class GrpcExtensionsTests { [Test] public void TestDefaultBrainParametersToProto() { // Should be able to convert a default instance to proto. var brain = new BrainParameters(); brain.ToProto("foo", false); } [Test] public void TestDefaultActionSpecToProto() { // Should be able to convert a default instance to proto. var actionSpec = new ActionSpec(); actionSpec.ToBrainParametersProto("foo", false); // Continuous actionSpec = ActionSpec.MakeContinuous(3); actionSpec.ToBrainParametersProto("foo", false); // Discrete actionSpec = ActionSpec.MakeDiscrete(1, 2, 3); actionSpec.ToBrainParametersProto("foo", false); } [Test] public void TestDefaultAgentInfoToProto() { // Should be able to convert a default instance to proto. var agentInfo = new AgentInfo(); agentInfo.ToInfoActionPairProto(); agentInfo.ToAgentInfoProto(); } [Test] public void TestDefaultDemonstrationMetaDataToProto() { // Should be able to convert a default instance to proto. var demoMetaData = new DemonstrationMetaData(); demoMetaData.ToProto(); } class DummySensor : ISensor { public int[] Shape; public SensorCompressionType CompressionType; internal DummySensor() { } public int[] GetObservationShape() { return Shape; } public int Write(ObservationWriter writer) { return 0; } public byte[] GetCompressedObservation() { return new byte[] { 13, 37 }; } public void Update() { } public void Reset() { } public SensorCompressionType GetCompressionType() { return CompressionType; } public string GetName() { return "Dummy"; } } [Test] public void TestGetObservationProtoCapabilities() { // Shape, compression type, concatenatedPngObservations, expect throw var variants = new[] { // Vector observations (new[] {3}, SensorCompressionType.None, false, false), // Uncompressed floats (new[] {4, 4, 3}, SensorCompressionType.None, false, false), // Compressed floats, 3 channels (new[] {4, 4, 3}, SensorCompressionType.PNG, false, true), // Compressed floats, >3 channels (new[] {4, 4, 4}, SensorCompressionType.PNG, false, false), // Unsupported - results in uncompressed (new[] {4, 4, 4}, SensorCompressionType.PNG, true, true), // Supported compressed }; foreach (var (shape, compressionType, supportsMultiPngObs, expectCompressed) in variants) { var dummySensor = new DummySensor(); var obsWriter = new ObservationWriter(); dummySensor.Shape = shape; dummySensor.CompressionType = compressionType; obsWriter.SetTarget(new float[128], shape, 0); var caps = new UnityRLCapabilities { ConcatenatedPngObservations = supportsMultiPngObs }; Academy.Instance.TrainerCapabilities = caps; var obsProto = dummySensor.GetObservationProto(obsWriter); if (expectCompressed) { Assert.Greater(obsProto.CompressedData.Length, 0); Assert.AreEqual(obsProto.FloatData, null); } else { Assert.Greater(obsProto.FloatData.Data.Count, 0); Assert.AreEqual(obsProto.CompressedData.Length, 0); } } } } }