using System;
using System.Collections.Generic;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Analytics
{
/// <summary>
/// Analytics payload describing a model used for inference.
/// NOTE(review): fields are not populated anywhere in the visible code; per-field notes
/// below that go beyond the declared types are inferred from names and should be confirmed.
/// </summary>
internal struct InferenceEvent
{
    /// <summary>
    /// Hash of the BehaviorName.
    /// </summary>
    public string BehaviorName;

    // Presumably metadata read from the Barracuda model asset and package -- TODO confirm against the sender.
    public string BarracudaModelSource;
    public string BarracudaModelVersion;
    public string BarracudaModelProducer;
    public string BarracudaPackageVersion;

    /// <summary>
    /// Whether inference is performed on CPU (0) or GPU (1).
    /// </summary>
    public int InferenceDevice;

    /// <summary>
    /// Observation summaries, one <see cref="EventObservationSpec"/> per entry.
    /// </summary>
    public List<EventObservationSpec> ObservationSpecs;

    /// <summary>
    /// Simplified action specification.
    /// </summary>
    public EventActionSpec ActionSpec;

    /// <summary>
    /// Actuator summaries, one <see cref="EventActuatorInfo"/> per entry.
    /// </summary>
    public List<EventActuatorInfo> ActuatorInfos;

    public int MemorySize;
    public long TotalWeightSizeBytes;
    public string ModelHash;
}
/// <summary>
/// Simplified version of ActionSpec struct for use in analytics
/// </summary>
[Serializable]
internal struct EventActionSpec
{
    public int NumContinuousActions;
    public int NumDiscreteActions;
    public int[] BranchSizes;

    /// <summary>
    /// Copies the action counts and branch sizes out of an <see cref="ActionSpec"/>.
    /// </summary>
    /// <param name="actionSpec">The action specification to summarize.</param>
    /// <returns>
    /// A new <see cref="EventActionSpec"/>; <see cref="BranchSizes"/> is never null
    /// (an empty array is substituted when the source has none).
    /// </returns>
    public static EventActionSpec FromActionSpec(ActionSpec actionSpec)
    {
        // Array.Empty needs an explicit element type; it must match the int[] field.
        var branchSizes = actionSpec.BranchSizes ?? Array.Empty<int>();
        return new EventActionSpec
        {
            NumContinuousActions = actionSpec.NumContinuousActions,
            NumDiscreteActions = actionSpec.NumDiscreteActions,
            BranchSizes = branchSizes,
        };
    }
}
/// <summary>
/// Information about an actuator.
/// </summary>
[Serializable]
internal struct EventActuatorInfo
{
    public int BuiltInActuatorType;
    public int NumContinuousActions;
    public int NumDiscreteActions;

    /// <summary>
    /// Summarizes an actuator: its built-in type (Unknown when the actuator does not
    /// implement <see cref="IBuiltInActuator"/>) and the action counts from its ActionSpec.
    /// </summary>
    /// <param name="actuator">The actuator to summarize.</param>
    /// <returns>A populated <see cref="EventActuatorInfo"/>.</returns>
    public static EventActuatorInfo FromActuator(IActuator actuator)
    {
        // The enum is reached through the Actuators namespace because the bare name
        // would resolve to the int field declared above.
        var actuatorKind = actuator is IBuiltInActuator knownActuator
            ? knownActuator.GetBuiltInActuatorType()
            : Actuators.BuiltInActuatorType.Unknown;

        var spec = actuator.ActionSpec;

        var info = default(EventActuatorInfo);
        info.BuiltInActuatorType = (int)actuatorKind;
        info.NumContinuousActions = spec.NumContinuousActions;
        info.NumDiscreteActions = spec.NumDiscreteActions;
        return info;
    }
}
/// <summary>
/// Information about one dimension of an observation.
/// </summary>
[Serializable]
internal struct EventObservationDimensionInfo
{
/// <summary>
/// Size of this dimension (EventObservationSpec.FromSensor fills it from the sensor's observation shape).
/// </summary>
public int Size;
/// <summary>
/// Flags for this dimension. EventObservationSpec.FromSensor does not set this yet, so it stays at 0 there.
/// NOTE(review): the meaning of individual flag bits is not defined in the visible code.
/// </summary>
public int Flags;
}
/// <summary>
/// Simplified summary of Agent observations for use in analytics
/// </summary>
[Serializable]
internal struct EventObservationSpec
{
    public string SensorName;
    public string CompressionType;
    public int BuiltInSensorType;
    public EventObservationDimensionInfo[] DimensionInfos;

    /// <summary>
    /// Summarizes a sensor: its name, compression type (as a string), built-in sensor type
    /// (Unknown when the sensor does not implement <see cref="IBuiltInSensor"/>) and one
    /// <see cref="EventObservationDimensionInfo"/> per entry of its observation shape.
    /// </summary>
    /// <param name="sensor">The sensor to summarize.</param>
    /// <returns>A populated <see cref="EventObservationSpec"/>.</returns>
    public static EventObservationSpec FromSensor(ISensor sensor)
    {
        var observationShape = sensor.GetObservationShape();
        var dimensions = new EventObservationDimensionInfo[observationShape.Length];
        for (var d = 0; d < observationShape.Length; ++d)
        {
            // TODO copy flags when we have them
            dimensions[d] = new EventObservationDimensionInfo { Size = observationShape[d] };
        }

        // The enum is reached through the Sensors namespace because the bare name
        // would resolve to the int field declared above.
        var sensorKind = sensor is IBuiltInSensor knownSensor
            ? knownSensor.GetBuiltInSensorType()
            : Sensors.BuiltInSensorType.Unknown;

        var summary = default(EventObservationSpec);
        summary.SensorName = sensor.GetName();
        summary.CompressionType = sensor.GetCompressionType().ToString();
        summary.BuiltInSensorType = (int)sensorKind;
        summary.DimensionInfos = dimensions;
        return summary;
    }
}
/// <summary>
/// Analytics payload describing a remote policy at initialization.
/// NOTE(review): fields are not populated anywhere in the visible code; confirm semantics with the sender.
/// </summary>
internal struct RemotePolicyInitializedEvent
{
    public string TrainingSessionGuid;

    /// <summary>
    /// Hash of the BehaviorName.
    /// </summary>
    public string BehaviorName;

    /// <summary>
    /// Observation summaries, one <see cref="EventObservationSpec"/> per entry.
    /// </summary>
    public List<EventObservationSpec> ObservationSpecs;

    /// <summary>
    /// Simplified action specification.
    /// </summary>
    public EventActionSpec ActionSpec;

    /// <summary>
    /// Actuator summaries, one <see cref="EventActuatorInfo"/> per entry.
    /// </summary>
    public List<EventActuatorInfo> ActuatorInfos;

    /// <summary>
    /// This will be the same as the value in <see cref="TrainingEnvironmentInitializedEvent"/>
    /// if available, but TrainingEnvironmentInitializedEvent may not always be available
    /// with older trainers.
    /// </summary>
    public string MLAgentsEnvsVersion;

    public string TrainerCommunicationVersion;
}
/// <summary>
/// Analytics payload holding version and configuration strings/counts for a training environment.
/// NOTE(review): nothing in the visible code populates these fields; the per-field notes
/// below are inferred from the field names and should be confirmed against the sender.
/// </summary>
internal struct TrainingEnvironmentInitializedEvent
{
// Presumably correlates this event with other events carrying the same TrainingSessionGuid -- TODO confirm.
public string TrainingSessionGuid;
// Version strings; presumably reported by the Python trainer -- TODO confirm.
public string TrainerPythonVersion;
public string MLAgentsVersion;
public string MLAgentsEnvsVersion;
public string TorchVersion;
public string TorchDeviceType;
// Counts; exact definition (e.g. what qualifies as an environment parameter) is not visible here.
public int NumEnvironments;
public int NumEnvironmentParameters;
}
/// <summary>
/// Bit flags naming reward signals; values are distinct powers of two so they can be combined.
/// </summary>
[Flags]
internal enum RewardSignals
{
    Extrinsic = 0x1, // bit 0
    Gail = 0x2,      // bit 1
    Curiosity = 0x4, // bit 2
    Rnd = 0x8,       // bit 3
}
/// <summary>
/// Bit flags naming training features; values are distinct powers of two so they can be combined.
/// </summary>
[Flags]
internal enum TrainingFeatures
{
    BehavioralCloning = 0x01, // bit 0
    Recurrent = 0x02,         // bit 1
    Threaded = 0x04,          // bit 2
    SelfPlay = 0x08,          // bit 3
    Curriculum = 0x10,        // bit 4
}
/// <summary>
/// Analytics payload holding trainer configuration for one behavior.
/// NOTE(review): nothing in the visible code populates these fields; notes that go beyond
/// the declared types are inferred from the field names and should be confirmed.
/// </summary>
internal struct TrainingBehaviorInitializedEvent
{
// Presumably correlates this event with other events carrying the same TrainingSessionGuid -- TODO confirm.
public string TrainingSessionGuid;
// NOTE(review): other events in this file document BehaviorName as a hash; confirm whether that applies here.
public string BehaviorName;
public string TrainerType;
// Typed as the [Flags] enum RewardSignals, so it can hold a bitwise combination of its members.
public RewardSignals RewardSignalFlags;
// Typed as the [Flags] enum TrainingFeatures, so it can hold a bitwise combination of its members.
public TrainingFeatures TrainingFeatureFlags;
public string VisualEncoder;
public int NumNetworkLayers;
public int NumNetworkHiddenUnits;
}
}