using System;
using System.Collections.Generic;
using System.Linq;
using Google.Protobuf;
using Google.Protobuf.Collections;
using MLAgents.CommunicatorObjects;
using UnityEngine;

namespace MLAgents
{
    /// <summary>
    /// Extension methods that convert between ML-Agents C# data types and the
    /// protobuf-generated message types used by the communicator.
    /// </summary>
    public static class GrpcExtensions
    {
        /// <summary>
        /// Converts an <see cref="AgentInfo"/> to a protobuf-generated <see cref="AgentInfoProto"/>.
        /// Every visual observation is PNG-encoded and copied into a <see cref="ByteString"/>.
        /// </summary>
        /// <param name="ai">The agent info to convert.</param>
        /// <returns>The protobuf version of the AgentInfo.</returns>
        public static AgentInfoProto ToProto(this AgentInfo ai)
        {
            var agentInfoProto = new AgentInfoProto
            {
                StackedVectorObservation = { ai.stackedVectorObservation },
                StoredVectorActions = { ai.storedVectorActions },
                StoredTextActions = ai.storedTextActions,
                TextObservation = ai.textObservation,
                Reward = ai.reward,
                MaxStepReached = ai.maxStepReached,
                Done = ai.done,
                Id = ai.id,
                CustomObservation = ai.customObservation
            };

            // Memories and action masks may be null; they are only copied when present.
            if (ai.memories != null)
            {
                agentInfoProto.Memories.Add(ai.memories);
            }

            if (ai.actionMasks != null)
            {
                agentInfoProto.ActionMask.AddRange(ai.actionMasks);
            }

            foreach (var obs in ai.visualObservations)
            {
                // Each PNG encode is timed individually under the "encodeVisualObs" timer.
                using (TimerStack.Instance.Scoped("encodeVisualObs"))
                {
                    agentInfoProto.VisualObservations.Add(
                        ByteString.CopyFrom(obs.EncodeToPNG()));
                }
            }

            return agentInfoProto;
        }

        /// <summary>
        /// Converts a <see cref="BrainParameters"/> into a <see cref="BrainParametersProto"/>
        /// so it can be sent through the communicator.
        /// </summary>
        /// <param name="bp">The instance of BrainParameters to extend.</param>
        /// <param name="name">The name of the brain.</param>
        /// <param name="isTraining">Whether or not the Brain is training.</param>
        /// <returns>The BrainParametersProto generated.</returns>
        public static BrainParametersProto ToProto(this BrainParameters bp, string name, bool isTraining)
        {
            var brainParametersProto = new BrainParametersProto
            {
                VectorObservationSize = bp.vectorObservationSize,
                NumStackedVectorObservations = bp.numStackedVectorObservations,
                VectorActionSize = { bp.vectorActionSize },
                // Direct enum cast; NOTE(review): assumes SpaceType and SpaceTypeProto share underlying values.
                VectorActionSpaceType = (SpaceTypeProto)bp.vectorActionSpaceType,
                BrainName = name,
                IsTraining = isTraining
            };
            brainParametersProto.VectorActionDescriptions.AddRange(bp.vectorActionDescriptions);

            foreach (var res in bp.cameraResolutions)
            {
                brainParametersProto.CameraResolutions.Add(
                    new ResolutionProto
                    {
                        Width = res.width,
                        Height = res.height,
                        GrayScale = res.blackAndWhite
                    });
            }

            return brainParametersProto;
        }

        /// <summary>
        /// Converts a <see cref="DemonstrationMetaData"/> object to a <see cref="DemonstrationMetaProto"/>.
        /// The proto is stamped with the current <c>DemonstrationMetaData.ApiVersion</c>.
        /// </summary>
        /// <param name="dm">The demonstration metadata to convert.</param>
        /// <returns>The protobuf version of the metadata.</returns>
        public static DemonstrationMetaProto ToProto(this DemonstrationMetaData dm)
        {
            var demoProto = new DemonstrationMetaProto
            {
                ApiVersion = DemonstrationMetaData.ApiVersion,
                MeanReward = dm.meanReward,
                NumberSteps = dm.numberExperiences,
                NumberEpisodes = dm.numberEpisodes,
                DemonstrationName = dm.demonstrationName
            };
            return demoProto;
        }

        /// <summary>
        /// Builds a <see cref="DemonstrationMetaData"/> from a <see cref="DemonstrationMetaProto"/>.
        /// </summary>
        /// <param name="demoProto">The protobuf metadata to read.</param>
        /// <returns>A new DemonstrationMetaData populated from the proto.</returns>
        /// <exception cref="Exception">
        /// Thrown when the proto's API version differs from <c>DemonstrationMetaData.ApiVersion</c>.
        /// </exception>
        public static DemonstrationMetaData ToDemonstrationMetaData(this DemonstrationMetaProto demoProto)
        {
            // Reject incompatible demonstrations before doing any conversion work.
            if (demoProto.ApiVersion != DemonstrationMetaData.ApiVersion)
            {
                throw new Exception("API versions of demonstration are incompatible.");
            }

            return new DemonstrationMetaData
            {
                numberEpisodes = demoProto.NumberEpisodes,
                numberExperiences = demoProto.NumberSteps,
                meanReward = demoProto.MeanReward,
                demonstrationName = demoProto.DemonstrationName
            };
        }

        /// <summary>
        /// Converts a list of <see cref="ResolutionProto"/> messages to a native <see cref="Resolution"/> array.
        /// </summary>
        /// <param name="resolutionProtos">The protobuf resolutions to convert.</param>
        /// <returns>An array with one Resolution per input element, in the same order.</returns>
        private static Resolution[] ResolutionProtoToNative(IReadOnlyList<ResolutionProto> resolutionProtos)
        {
            var localCameraResolutions = new Resolution[resolutionProtos.Count];
            for (var i = 0; i < resolutionProtos.Count; i++)
            {
                localCameraResolutions[i] = new Resolution
                {
                    height = resolutionProtos[i].Height,
                    width = resolutionProtos[i].Width,
                    blackAndWhite = resolutionProtos[i].GrayScale
                };
            }

            return localCameraResolutions;
        }

        /// <summary>
        /// Converts a <see cref="BrainParametersProto"/> to a <see cref="BrainParameters"/> struct.
        /// </summary>
        /// <param name="bpp">An instance of a brain parameters protobuf object.</param>
        /// <returns>A BrainParameters struct.</returns>
        public static BrainParameters ToBrainParameters(this BrainParametersProto bpp)
        {
            var bp = new BrainParameters
            {
                vectorObservationSize = bpp.VectorObservationSize,
                cameraResolutions = ResolutionProtoToNative(bpp.CameraResolutions),
                numStackedVectorObservations = bpp.NumStackedVectorObservations,
                vectorActionSize = bpp.VectorActionSize.ToArray(),
                vectorActionDescriptions = bpp.VectorActionDescriptions.ToArray(),
                // Direct enum cast; NOTE(review): assumes SpaceTypeProto and SpaceType share underlying values.
                vectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceType
            };
            return bp;
        }

        /// <summary>
        /// Converts a protobuf MapField of named float parameters to <see cref="ResetParameters"/>.
        /// </summary>
        /// <param name="floatParams">The mapping of strings to floats from a protobuf MapField.</param>
        /// <returns>A new ResetParameters constructed from <paramref name="floatParams"/>.</returns>
        public static ResetParameters ToResetParameters(this MapField<string, float> floatParams)
        {
            return new ResetParameters(floatParams);
        }

        /// <summary>
        /// Converts an <see cref="EnvironmentParametersProto"/> protobuf object to an
        /// <see cref="EnvironmentResetParameters"/> struct.
        /// </summary>
        /// <param name="epp">The instance of the EnvironmentParametersProto object.</param>
        /// <returns>A new EnvironmentResetParameters struct.</returns>
        public static EnvironmentResetParameters ToEnvironmentResetParameters(this EnvironmentParametersProto epp)
        {
            return new EnvironmentResetParameters
            {
                resetParameters = epp.FloatParameters?.ToResetParameters(),
                customResetParameters = epp.CustomResetParameters
            };
        }

        /// <summary>
        /// Converts a <see cref="UnityRLInitializationInputProto"/> to <see cref="UnityRLInitParameters"/>.
        /// </summary>
        /// <param name="inputProto">The initialization input received from the communicator.</param>
        /// <returns>A new UnityRLInitParameters carrying the proto's seed.</returns>
        public static UnityRLInitParameters ToUnityRLInitParameters(this UnityRLInitializationInputProto inputProto)
        {
            return new UnityRLInitParameters
            {
                seed = inputProto.Seed
            };
        }

        /// <summary>
        /// Converts an <see cref="AgentActionProto"/> to an <see cref="AgentAction"/>.
        /// Vector actions and memories are copied into new collections.
        /// </summary>
        /// <param name="aap">The protobuf agent action to convert.</param>
        /// <returns>A new AgentAction populated from the proto.</returns>
        public static AgentAction ToAgentAction(this AgentActionProto aap)
        {
            return new AgentAction
            {
                vectorActions = aap.VectorActions.ToArray(),
                textActions = aap.TextActions,
                memories = aap.Memories.ToList(),
                value = aap.Value,
                customAction = aap.CustomAction
            };
        }

        /// <summary>
        /// Converts every <see cref="AgentActionProto"/> in a ListAgentActionProto to an
        /// <see cref="AgentAction"/>, preserving order.
        /// </summary>
        /// <param name="proto">The protobuf list of agent actions.</param>
        /// <returns>A new list containing one AgentAction per proto element.</returns>
        public static List<AgentAction> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
        {
            // Presize the list to avoid regrowth while appending.
            var agentActions = new List<AgentAction>(proto.Value.Count);
            foreach (var ap in proto.Value)
            {
                agentActions.Add(ap.ToAgentAction());
            }

            return agentActions;
        }
    }
}