using System;
using System.Collections.Generic;
using System.Linq;
using Google.Protobuf;
using Google.Protobuf.Collections;
using MLAgents.CommunicatorObjects;
using MLAgents.Sensor;
using UnityEngine;
namespace MLAgents
{
public static class GrpcExtensions
{
/// <summary>
/// Converts an <see cref="AgentInfo"/> into its protobuf-generated
/// <see cref="AgentInfoProto"/> counterpart.
/// </summary>
/// <param name="ai">The agent info to convert.</param>
/// <returns>The protobuf version of the AgentInfo.</returns>
public static AgentInfoProto ToProto(this AgentInfo ai)
{
    var proto = new AgentInfoProto();

    // Fields that are always copied, in the order the proto is populated.
    proto.StackedVectorObservation.AddRange(ai.stackedVectorObservation);
    proto.StoredVectorActions.AddRange(ai.storedVectorActions);
    proto.StoredTextActions = ai.storedTextActions;
    proto.TextObservation = ai.textObservation;
    proto.Reward = ai.reward;
    proto.MaxStepReached = ai.maxStepReached;
    proto.Done = ai.done;
    proto.Id = ai.id;
    proto.CustomObservation = ai.customObservation;

    // Optional collections are only copied when the agent actually has them.
    if (ai.memories != null)
    {
        proto.Memories.AddRange(ai.memories);
    }

    if (ai.actionMasks != null)
    {
        proto.ActionMask.AddRange(ai.actionMasks);
    }

    if (ai.compressedObservations != null)
    {
        foreach (var compressed in ai.compressedObservations)
        {
            proto.CompressedObservations.Add(compressed.ToProto());
        }
    }

    return proto;
}
/// <summary>
/// Converts a <see cref="BrainParameters"/> into a protobuf
/// <see cref="BrainParametersProto"/> so it can be sent.
/// </summary>
/// <param name="bp">The instance of BrainParameters to extend.</param>
/// <param name="name">The name of the brain.</param>
/// <param name="isTraining">Whether or not the Brain is training.</param>
/// <returns>The BrainParametersProto generated.</returns>
public static BrainParametersProto ToProto(this BrainParameters bp, string name, bool isTraining)
{
    return new BrainParametersProto
    {
        VectorObservationSize = bp.vectorObservationSize,
        NumStackedVectorObservations = bp.numStackedVectorObservations,
        VectorActionSize = { bp.vectorActionSize },
        VectorActionSpaceType = (SpaceTypeProto)bp.vectorActionSpaceType,
        BrainName = name,
        IsTraining = isTraining,
        // Kept last so the descriptions are appended after every scalar field is set.
        VectorActionDescriptions = { bp.vectorActionDescriptions }
    };
}
/// <summary>
/// Converts a <see cref="DemonstrationMetaData"/> object into a
/// <see cref="DemonstrationMetaProto"/>.
/// </summary>
/// <param name="dm">The demonstration metadata to convert.</param>
/// <returns>
/// A proto carrying the metadata values, stamped with
/// <c>DemonstrationMetaData.ApiVersion</c>.
/// </returns>
public static DemonstrationMetaProto ToProto(this DemonstrationMetaData dm)
{
    return new DemonstrationMetaProto
    {
        ApiVersion = DemonstrationMetaData.ApiVersion,
        MeanReward = dm.meanReward,
        NumberSteps = dm.numberExperiences,
        NumberEpisodes = dm.numberEpisodes,
        DemonstrationName = dm.demonstrationName
    };
}
/// <summary>
/// Builds a <see cref="DemonstrationMetaData"/> from the values stored in a
/// <see cref="DemonstrationMetaProto"/>.
/// </summary>
/// <param name="demoProto">The proto object to read the metadata from.</param>
/// <returns>The metadata populated from the proto.</returns>
/// <exception cref="Exception">
/// Thrown when the proto's ApiVersion differs from
/// <c>DemonstrationMetaData.ApiVersion</c>.
/// </exception>
public static DemonstrationMetaData ToDemonstrationMetaData(this DemonstrationMetaProto demoProto)
{
    var metaData = new DemonstrationMetaData();
    metaData.numberEpisodes = demoProto.NumberEpisodes;
    metaData.numberExperiences = demoProto.NumberSteps;
    metaData.meanReward = demoProto.MeanReward;
    metaData.demonstrationName = demoProto.DemonstrationName;

    // A demonstration recorded with a different API version cannot be used.
    var versionMatches = demoProto.ApiVersion == DemonstrationMetaData.ApiVersion;
    if (!versionMatches)
    {
        throw new Exception("API versions of demonstration are incompatible.");
    }

    return metaData;
}
/// <summary>
/// Converts a <see cref="BrainParametersProto"/> into a
/// <see cref="BrainParameters"/> struct.
/// </summary>
/// <param name="bpp">An instance of a brain parameters protobuf object.</param>
/// <returns>A BrainParameters struct.</returns>
public static BrainParameters ToBrainParameters(this BrainParametersProto bpp)
{
    // The repeated proto fields are copied into fresh arrays.
    return new BrainParameters
    {
        vectorObservationSize = bpp.VectorObservationSize,
        numStackedVectorObservations = bpp.NumStackedVectorObservations,
        vectorActionSize = bpp.VectorActionSize.ToArray(),
        vectorActionDescriptions = bpp.VectorActionDescriptions.ToArray(),
        vectorActionSpaceType = (SpaceType)bpp.VectorActionSpaceType
    };
}
/// <summary>
/// Converts a <see cref="MapField{TKey, TValue}"/> into
/// <see cref="ResetParameters"/>.
/// </summary>
/// <param name="floatParams">
/// The mapping of strings to floats from a protobuf MapField.
/// </param>
/// <returns>
/// A new ResetParameters instance constructed from <paramref name="floatParams"/>.
/// </returns>
public static ResetParameters ToResetParameters(this MapField<string, float> floatParams)
{
    // MapField is generic; the type arguments are required for this to compile.
    return new ResetParameters(floatParams);
}
/// <summary>
/// Converts an <see cref="EnvironmentParametersProto"/> protobuf object into an
/// <see cref="EnvironmentResetParameters"/> struct.
/// </summary>
/// <param name="epp">The instance of the EnvironmentParametersProto object.</param>
/// <returns>A new EnvironmentResetParameters struct.</returns>
public static EnvironmentResetParameters ToEnvironmentResetParameters(this EnvironmentParametersProto epp)
{
    var envResetParams = new EnvironmentResetParameters();

    // resetParameters stays null when the proto carries no float parameters.
    envResetParams.resetParameters = epp.FloatParameters?.ToResetParameters();
    envResetParams.customResetParameters = epp.CustomResetParameters;

    return envResetParams;
}
/// <summary>
/// Converts a <see cref="UnityRLInitializationInputProto"/> into
/// <see cref="UnityRLInitParameters"/>.
/// </summary>
/// <param name="inputProto">The initialization input proto to read from.</param>
/// <returns>Init parameters whose seed is taken from the proto.</returns>
public static UnityRLInitParameters ToUnityRLInitParameters(this UnityRLInitializationInputProto inputProto)
{
    var initParams = new UnityRLInitParameters();
    initParams.seed = inputProto.Seed;
    return initParams;
}
/// <summary>
/// Converts an <see cref="AgentActionProto"/> into an <see cref="AgentAction"/>.
/// </summary>
/// <param name="aap">The agent action proto to convert.</param>
/// <returns>
/// An AgentAction holding the proto's values; vector actions are copied into an
/// array and memories into a list.
/// </returns>
public static AgentAction ToAgentAction(this AgentActionProto aap)
{
    var action = new AgentAction();
    action.vectorActions = aap.VectorActions.ToArray();
    action.textActions = aap.TextActions;
    action.memories = aap.Memories.ToList();
    action.value = aap.Value;
    action.customAction = aap.CustomAction;
    return action;
}
/// <summary>
/// Converts every action in a <c>ListAgentActionProto</c> into an
/// <see cref="AgentAction"/>.
/// </summary>
/// <param name="proto">The proto whose <c>Value</c> entries are converted.</param>
/// <returns>
/// A list with one AgentAction per entry of <c>proto.Value</c>, in the same order.
/// </returns>
public static List<AgentAction> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
{
    // Pre-size the list: the number of resulting actions is known up front.
    var agentActions = new List<AgentAction>(proto.Value.Count);
    foreach (var ap in proto.Value)
    {
        agentActions.Add(ap.ToAgentAction());
    }
    return agentActions;
}
/// <summary>
/// Converts a <see cref="CompressedObservation"/> into a
/// <see cref="CompressedObservationProto"/>.
/// </summary>
/// <param name="obs">The compressed observation to convert.</param>
/// <returns>
/// A proto holding a copy of the observation's data, its compression type and
/// its shape.
/// </returns>
public static CompressedObservationProto ToProto(this CompressedObservation obs)
{
    return new CompressedObservationProto
    {
        Data = ByteString.CopyFrom(obs.Data),
        CompressionType = (CompressionTypeProto)obs.CompressionType,
        // Listed last so the shape is appended after the scalar fields are set.
        Shape = { obs.Shape }
    };
}
}
}