/// </summary>
internal class BarracudaModelParamLoader
{
const long k_ApiVersion = 2 ;
internal enum ModelApiVersion
{
MLAgents1_0 = 2 ,
MLAgents2_0 = 3 ,
MinSupportedVersion = MLAgents1_0 ,
MaxSupportedVersion = MLAgents2_0
}
internal class FailedCheck
{
public enum CheckTypeEnum
{
Info = 0 ,
Warning = 1 ,
Error = 2
}
public CheckTypeEnum CheckType ;
public string Message ;
public static FailedCheck Info ( string message )
{
return new FailedCheck { CheckType = CheckTypeEnum . Info , Message = message } ;
}
public static FailedCheck Warning ( string message )
{
return new FailedCheck { CheckType = CheckTypeEnum . Warning , Message = message } ;
}
public static FailedCheck Error ( string message )
{
return new FailedCheck { CheckType = CheckTypeEnum . Error , Message = message } ;
}
}
/// <summary>
/// Factory for the ModelParamLoader : Creates a ModelParamLoader and runs the checks
/// <param name="actuatorComponents">Attached actuator components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <param name="behaviorType">BehaviorType or the Agent to check.</param>
/// <returns>The list the error messages of the checks that failed</returns>
public static IEnumerable < string > CheckModel ( Model model , BrainParameters brainParameters ,
ISensor [ ] sensors , ActuatorComponent [ ] actuatorComponents ,
/// <returns>A IEnumerable of the checks that failed</returns>
public static IEnumerable < FailedCheck > CheckModel (
Model model ,
BrainParameters brainParameters ,
ISensor [ ] sensors ,
ActuatorComponent [ ] actuatorComponents ,
BehaviorType behaviorType = BehaviorType . Default )
BehaviorType behaviorType = BehaviorType . Default
)
List < string > failedModelChecks = new List < string > ( ) ;
List < FailedCheck > failedModelChecks = new List < FailedCheck > ( ) ;
if ( model = = null )
{
var errorMsg = "There is no model for this Brain; cannot run inference. " ;
{
errorMsg + = "(But can still train)" ;
}
failedModelChecks . Add ( errorMsg ) ;
failedModelChecks . Add ( FailedCheck . Info ( errorMsg ) ) ;
return failedModelChecks ;
}
return failedModelChecks ;
}
var modelApiVersion = ( int ) model . GetTensorByName ( TensorNames . VersionNumber ) [ 0 ] ;
if ( modelApiVersion = = - 1 )
var modelApiVersion = model . GetVersion ( ) ;
if ( modelApiVersion < ( int ) ModelApiVersion . MinSupportedVersion | | modelApiVersion > ( int ) ModelApiVersion . MaxSupportedVersion )
"Model was not trained using the right version of ML-Agents. " +
"Cannot use this model." ) ;
return failedModelChecks ;
}
if ( modelApiVersion ! = k_ApiVersion )
{
failedModelChecks . Add (
$"Version of the trainer the model was trained with ({modelApiVersion}) " +
$"is not compatible with the Brain's version ({k_ApiVersion})." ) ;
FailedCheck . Warning ( $"Version of the trainer the model was trained with ({modelApiVersion}) " +
$"is not compatible with the current range of supported versions: " +
$"({(int)ModelApiVersion.MinSupportedVersion} to {(int)ModelApiVersion.MaxSupportedVersion})." )
) ;
return failedModelChecks ;
}
failedModelChecks . Add ( $"Missing node in the model provided : {TensorNames.MemorySize}" ) ;
failedModelChecks . Add ( FailedCheck . Warning ( $"Missing node in the model provided : {TensorNames.MemorySize}"
) ) ;
if ( modelApiVersion = = ( int ) ModelApiVersion . MLAgents1_0 )
{
failedModelChecks . AddRange (
CheckInputTensorPresenceLegacy ( model , brainParameters , memorySize , sensors )
) ;
failedModelChecks . AddRange (
CheckInputTensorShapeLegacy ( model , brainParameters , sensors , observableAttributeTotalSize )
) ;
}
else if ( modelApiVersion = = ( int ) ModelApiVersion . MLAgents2_0 )
{
failedModelChecks . AddRange (
CheckInputTensorPresence ( model , brainParameters , memorySize , sensors )
) ;
failedModelChecks . AddRange (
CheckInputTensorShape ( model , brainParameters , sensors , observableAttributeTotalSize )
) ;
}
CheckInputTensorPresence ( model , brainParameters , memorySize , sensors )
CheckOutputTensorShape ( model , brainParameters , actuatorComponents )
) ;
failedModelChecks . AddRange (
CheckInputTensorShape ( model , brainParameters , sensors , observableAttributeTotalSize )
) ;
failedModelChecks . AddRange (
CheckOutputTensorShape ( model , brainParameters , actuatorComponents )
) ;
return failedModelChecks ;
}
/// present in the BrainParameters.
/// present in the BrainParameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters
/// </param>
/// <param name="sensors">Array of attached sensor components</param>
/// <returns>
/// A IEnumerable of string corresponding to the failed input presence checks.
/// A IEnumerable of the checks that failed
static IEnumerable < string > CheckInputTensorPresence (
static IEnumerable < FailedCheck > CheckInputTensorPresenceLegacy (
Model model ,
BrainParameters brainParameters ,
int memory ,
var failedModelChecks = new List < string > ( ) ;
var failedModelChecks = new List < FailedCheck > ( ) ;
var tensorsNames = model . GetInputNames ( ) ;
// If there is no Vector Observation Input but the Brain Parameters expect one.
failedModelChecks . Add (
"The model does not contain a Vector Observation Placeholder Input. " +
"You must set the Vector Observation Space Size to 0." ) ;
FailedCheck . Warning ( "The model does not contain a Vector Observation Placeholder Input. " +
"You must set the Vector Observation Space Size to 0." )
) ;
}
// If there are not enough Visual Observation Input compared to what the
TensorNames . GetVisualObservationName ( visObsIndex ) ) )
{
failedModelChecks . Add (
"The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name})." ) ;
FailedCheck . Warning ( "The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name})." )
) ;
}
visObsIndex + + ;
}
TensorNames . GetObservationName ( sensorIndex ) ) )
{
failedModelChecks . Add (
"The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name})." ) ;
FailedCheck . Warning ( "The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name})." )
) ;
}
}
if ( expectedVisualObs > visObsIndex )
{
failedModelChecks . Add (
$"The model expects {expectedVisualObs} visual inputs," +
$" but only found {visObsIndex} visual sensors."
) ;
FailedCheck . Warning ( $"The model expects {expectedVisualObs} visual inputs," +
$" but only found {visObsIndex} visual sensors." )
) ;
}
// If the model has a non-negative memory size but requires a recurrent input
if ( memory > 0 )
{
if ( ! tensorsNames . Any ( x = > x . EndsWith ( "_h" ) ) | |
! tensorsNames . Any ( x = > x . EndsWith ( "_c" ) ) )
{
failedModelChecks . Add (
FailedCheck . Warning ( "The model does not contain a Recurrent Input Node but has memory_size." )
) ;
}
}
// If the model uses discrete control but does not have an input for action masks
if ( model . HasDiscreteOutputs ( ) )
{
if ( ! tensorsNames . Contains ( TensorNames . ActionMaskPlaceholder ) )
{
failedModelChecks . Add (
FailedCheck . Warning ( "The model does not contain an Action Mask but is using Discrete Control." )
) ;
}
}
return failedModelChecks ;
}
/// <summary>
/// Generates failed checks that correspond to inputs expected by the model that are not
/// present in the BrainParameters.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters
/// </param>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="memory">
/// The memory size that the model is expecting.
/// </param>
/// <param name="sensors">Array of attached sensor components</param>
/// <returns>
/// A IEnumerable of the checks that failed
/// </returns>
static IEnumerable < FailedCheck > CheckInputTensorPresence (
Model model ,
BrainParameters brainParameters ,
int memory ,
ISensor [ ] sensors
)
{
var failedModelChecks = new List < FailedCheck > ( ) ;
var tensorsNames = model . GetInputNames ( ) ;
for ( var sensorIndex = 0 ; sensorIndex < sensors . Length ; sensorIndex + + )
{
if ( ! tensorsNames . Contains (
TensorNames . GetObservationName ( sensorIndex ) ) )
{
var sensor = sensors [ sensorIndex ] ;
failedModelChecks . Add (
FailedCheck . Warning ( "The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name})." )
) ;
}
}
// If the model has a non-negative memory size but requires a recurrent input
! tensorsNames . Any ( x = > x . EndsWith ( "_c" ) ) )
{
failedModelChecks . Add (
"The model does not contain a Recurrent Input Node but has memory_size." ) ;
FailedCheck . Warning ( "The model does not contain a Recurrent Input Node but has memory_size." )
) ;
}
}
if ( ! tensorsNames . Contains ( TensorNames . ActionMaskPlaceholder ) )
{
failedModelChecks . Add (
"The model does not contain an Action Mask but is using Discrete Control." ) ;
FailedCheck . Warning ( "The model does not contain an Action Mask but is using Discrete Control." )
) ;
}
}
return failedModelChecks ;
/// </param>
/// <param name="memory">The memory size that the model is expecting/</param>
/// <returns>
/// A IEnumerable of string corresponding to the failed output presence checks.
/// A IEnumerable of the checks that failed
static IEnumerable < string > CheckOutputTensorPresence ( Model model , int memory )
static IEnumerable < FailedCheck > CheckOutputTensorPresence ( Model model , int memory )
var failedModelChecks = new List < string > ( ) ;
var failedModelChecks = new List < FailedCheck > ( ) ;
// If there is no Recurrent Output but the model is Recurrent.
if ( memory > 0 )
! memOutputs . Any ( x = > x . EndsWith ( "_c" ) ) )
{
failedModelChecks . Add (
"The model does not contain a Recurrent Output Node but has memory_size." ) ;
FailedCheck . Warning ( "The model does not contain a Recurrent Output Node but has memory_size." )
) ;
}
}
return failedModelChecks ;
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckVisualObsShape (
static FailedCheck CheckVisualObsShape (
TensorProxy tensorProxy , ISensor sensor )
{
var shape = sensor . GetObservationSpec ( ) . Shape ;
var pixelT = tensorProxy . Channels ;
if ( ( widthBp ! = widthT ) | | ( heightBp ! = heightT ) | | ( pixelBp ! = pixelT ) )
{
return $"The visual Observation of the model does not match. " +
return FailedCheck . Warning ( $"The visual Observation of the model does not match. " +
$"was expecting [?x{widthT}x{heightT}x{pixelT}]." ;
$"was expecting [?x{widthT}x{heightT}x{pixelT}] for the {sensor.GetName()} Sensor."
) ;
}
return null ;
}
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckRankTwoObsShape (
static FailedCheck CheckRankTwoObsShape (
TensorProxy tensorProxy , ISensor sensor )
{
var shape = sensor . GetObservationSpec ( ) . Shape ;
var dim2T = tensorProxy . Width ;
var dim3T = tensorProxy . Height ;
return $"An Observation of the model does not match. " +
var proxyDimStr = $"[?x{dim1T}x{dim2T}]" ;
if ( dim3T > 1 )
{
proxyDimStr = $"[?x{dim3T}x{dim2T}x{dim1T}]" ;
}
return FailedCheck . Warning ( $"An Observation of the model does not match. " +
$"was expecting [?x{dim1T}x{dim2T}]." ;
$"was expecting {proxyDimStr} for the {sensor.GetName()} Sensor."
) ;
}
return null ;
}
/// <summary>
/// Checks that the shape of the rank 2 observation input placeholder is the same as the corresponding sensor.
/// </summary>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensor">The sensor that produces the visual observation.</param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static FailedCheck CheckRankOneObsShape (
TensorProxy tensorProxy , ISensor sensor )
{
var shape = sensor . GetObservationSpec ( ) . Shape ;
var dim1Bp = shape [ 0 ] ;
var dim1T = tensorProxy . Channels ;
var dim2T = tensorProxy . Width ;
var dim3T = tensorProxy . Height ;
if ( ( dim1Bp ! = dim1T ) )
{
var proxyDimStr = $"[?x{dim1T}]" ;
if ( dim2T > 1 )
{
proxyDimStr = $"[?x{dim1T}x{dim2T}]" ;
}
if ( dim3T > 1 )
{
proxyDimStr = $"[?x{dim3T}x{dim2T}x{dim1T}]" ;
}
return FailedCheck . Warning ( $"An Observation of the model does not match. " +
$"Received TensorProxy of shape [?x{dim1Bp}] but " +
$"was expecting {proxyDimStr} for the {sensor.GetName()} Sensor."
) ;
}
return null ;
}
/// the model and the BrainParameters.
/// the model and the BrainParameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters
/// </param>
/// <param name="sensors">Attached sensors</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <returns>The list the error messages of the checks that failed</returns>
static IEnumerable < string > CheckInputTensorShape (
/// <returns>A IEnumerable of the checks that failed</returns>
static IEnumerable < FailedCheck > CheckInputTensorShapeLegacy (
var failedModelChecks = new List < string > ( ) ;
var failedModelChecks = new List < FailedCheck > ( ) ;
new Dictionary < string , Func < BrainParameters , TensorProxy , ISensor [ ] , int , string > > ( )
new Dictionary < string , Func < BrainParameters , TensorProxy , ISensor [ ] , int , FailedCheck > > ( )
{ TensorNames . VectorObservationPlaceholder , CheckVectorObsShape } ,
{ TensorNames . VectorObservationPlaceholder , CheckVectorObsShapeLegacy } ,
{ TensorNames . PreviousActionPlaceholder , CheckPreviousActionShape } ,
{ TensorNames . RandomNormalEpsilonPlaceholder , ( ( bp , tensor , scs , i ) = > null ) } ,
{ TensorNames . ActionMaskPlaceholder , ( ( bp , tensor , scs , i ) = > null ) } ,
if ( ! tensor . name . Contains ( "visual_observation" ) )
{
failedModelChecks . Add (
"Model requires an unknown input named : " + tensor . name ) ;
FailedCheck . Warning ( "Model contains an unexpected input named : " + tensor . name )
) ;
}
}
else
/// <summary>
/// Checks that the shape of the Vector Observation input placeholder is the same in the
/// model and in the Brain Parameters.
/// model and in the Brain Parameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckVectorObsShape (
static FailedCheck CheckVectorObsShapeLegacy (
BrainParameters brainParameters , TensorProxy tensorProxy , ISensor [ ] sensors ,
int observableAttributeTotalSize )
{
}
sensorSizes + = "]" ;
return $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
return FailedCheck . Warning (
$"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
$"Sensor sizes: {sensorSizes}." ;
$"Sensor sizes: {sensorSizes}."
) ;
/// <summary>
/// Generates failed checks that correspond to inputs shapes incompatibilities between
/// the model and the BrainParameters.
/// </summary>
/// <param name="model">
/// The Barracuda engine model for loading static parameters
/// </param>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="sensors">Attached sensors</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <returns>A IEnumerable of the checks that failed</returns>
static IEnumerable < FailedCheck > CheckInputTensorShape (
Model model , BrainParameters brainParameters , ISensor [ ] sensors ,
int observableAttributeTotalSize )
{
var failedModelChecks = new List < FailedCheck > ( ) ;
var tensorTester =
new Dictionary < string , Func < BrainParameters , TensorProxy , ISensor [ ] , int , FailedCheck > > ( )
{
{ TensorNames . PreviousActionPlaceholder , CheckPreviousActionShape } ,
{ TensorNames . RandomNormalEpsilonPlaceholder , ( ( bp , tensor , scs , i ) = > null ) } ,
{ TensorNames . ActionMaskPlaceholder , ( ( bp , tensor , scs , i ) = > null ) } ,
{ TensorNames . SequenceLengthPlaceholder , ( ( bp , tensor , scs , i ) = > null ) } ,
{ TensorNames . RecurrentInPlaceholder , ( ( bp , tensor , scs , i ) = > null ) } ,
} ;
foreach ( var mem in model . memories )
{
tensorTester [ mem . input ] = ( ( bp , tensor , scs , i ) = > null ) ;
}
for ( var sensorIndex = 0 ; sensorIndex < sensors . Length ; sensorIndex + + )
{
var sens = sensors [ sensorIndex ] ;
if ( sens . GetObservationSpec ( ) . NumDimensions = = 3 )
{
tensorTester [ TensorNames . GetObservationName ( sensorIndex ) ] =
( bp , tensor , scs , i ) = > CheckVisualObsShape ( tensor , sens ) ;
}
if ( sens . GetObservationSpec ( ) . NumDimensions = = 2 )
{
tensorTester [ TensorNames . GetObservationName ( sensorIndex ) ] =
( bp , tensor , scs , i ) = > CheckRankTwoObsShape ( tensor , sens ) ;
}
if ( sens . GetObservationSpec ( ) . NumDimensions = = 1 )
{
tensorTester [ TensorNames . GetObservationName ( sensorIndex ) ] =
( bp , tensor , scs , i ) = > CheckRankOneObsShape ( tensor , sens ) ;
}
}
// If the model expects an input but it is not in this list
foreach ( var tensor in model . GetInputTensors ( ) )
{
if ( ! tensorTester . ContainsKey ( tensor . name ) )
{
failedModelChecks . Add ( FailedCheck . Warning ( "Model contains an unexpected input named : " + tensor . name
) ) ;
}
else
{
var tester = tensorTester [ tensor . name ] ;
var error = tester . Invoke ( brainParameters , tensor , sensors , observableAttributeTotalSize ) ;
if ( error ! = null )
{
failedModelChecks . Add ( error ) ;
}
}
}
return failedModelChecks ;
}
/// <summary>
/// Checks that the shape of the Previous Vector Action input placeholder is the same in the
/// model and in the Brain Parameters.
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes (unused).</param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
static string CheckPreviousActionShape (
static FailedCheck CheckPreviousActionShape (
BrainParameters brainParameters , TensorProxy tensorProxy ,
ISensor [ ] sensors , int observableAttributeTotalSize )
{
{
return "Previous Action Size of the model does not match. " +
$"Received {numberActionsBp} but was expecting {numberActionsT}." ;
return FailedCheck . Warning ( "Previous Action Size of the model does not match. " +
$"Received {numberActionsBp} but was expecting {numberActionsT}."
) ;
}
return null ;
}
/// </param>
/// <param name="actuatorComponents">Array of attached actuator components.</param>
/// <returns>
/// A IEnumerable of string corresponding to the incompatible shapes between model
/// A IEnumerable of error messages corresponding to the incompatible shapes between model
static IEnumerable < string > CheckOutputTensorShape (
static IEnumerable < FailedCheck > CheckOutputTensorShape (
var failedModelChecks = new List < string > ( ) ;
var failedModelChecks = new List < FailedCheck > ( ) ;
// If the model expects an output but it is not in this list
var modelContinuousActionSize = model . ContinuousOutputSize ( ) ;
failedModelChecks . Add ( continuousError ) ;
}
var modelSumDiscreteBranchSizes = model . DiscreteOutputSize ( ) ;
var discreteError = CheckDiscreteActionOutputShape ( brainParameters , actuatorComponents , modelSumDiscreteBranchSizes ) ;
FailedCheck discreteError = null ;
var modelApiVersion = model . GetVersion ( ) ;
if ( modelApiVersion = = ( int ) ModelApiVersion . MLAgents1_0 )
{
var modelSumDiscreteBranchSizes = model . DiscreteOutputSize ( ) ;
discreteError = CheckDiscreteActionOutputShapeLegacy ( brainParameters , actuatorComponents , modelSumDiscreteBranchSizes ) ;
}
if ( modelApiVersion = = ( int ) ModelApiVersion . MLAgents2_0 )
{
var modelDiscreteBranches = model . GetTensorByName ( TensorNames . DiscreteActionOutputShape ) ;
discreteError = CheckDiscreteActionOutputShape ( brainParameters , actuatorComponents , modelDiscreteBranches ) ;
}
if ( discreteError ! = null )
{
failedModelChecks . Add ( discreteError ) ;
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="actuatorComponents">Array of attached actuator components.</param>
/// <param name="modelDiscreteBranches"> The Tensor of branch sizes.
/// </param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static FailedCheck CheckDiscreteActionOutputShape (
BrainParameters brainParameters , ActuatorComponent [ ] actuatorComponents , Tensor modelDiscreteBranches )
{
var discreteActionBranches = brainParameters . ActionSpec . BranchSizes . ToList ( ) ;
foreach ( var actuatorComponent in actuatorComponents )
{
var actionSpec = actuatorComponent . ActionSpec ;
discreteActionBranches . AddRange ( actionSpec . BranchSizes ) ;
}
int modelDiscreteBranchesLength = modelDiscreteBranches ? . length ? ? 0 ;
if ( modelDiscreteBranchesLength ! = discreteActionBranches . Count )
{
return FailedCheck . Warning ( "Discrete Action Size of the model does not match. The BrainParameters expect " +
$"{discreteActionBranches.Count} branches but the model contains {modelDiscreteBranchesLength}."
) ;
}
for ( int i = 0 ; i < modelDiscreteBranchesLength ; i + + )
{
if ( modelDiscreteBranches [ i ] ! = discreteActionBranches [ i ] )
{
return FailedCheck . Warning ( $"The number of Discrete Actions of branch {i} does not match. " +
$"Was expecting {discreteActionBranches[i]} but the model contains {modelDiscreteBranches[i]} "
) ;
}
}
return null ;
}
/// <summary>
/// Checks that the shape of the discrete action output is the same in the
/// model and in the Brain Parameters. Tests the models created with the API of version 1.X
/// </summary>
/// <param name="brainParameters">
/// The BrainParameters that are used verify the compatibility with the InferenceEngine
/// </param>
/// <param name="actuatorComponents">Array of attached actuator components.</param>
/// <param name="modelSumDiscreteBranchSizes">
/// The size of the discrete action output that is expected by the model.
/// </param>
/// </returns>
static string CheckDiscreteActionOutputShape (
static FailedCheck CheckDiscreteActionOutputShapeLegacy (
BrainParameters brainParameters , ActuatorComponent [ ] actuatorComponents , int modelSumDiscreteBranchSizes )
{
// TODO: check each branch size instead of sum of branch sizes
if ( modelSumDiscreteBranchSizes ! = sumOfDiscreteBranchSizes )
{
return "Discrete Action Size of the model does not match. The BrainParameters expect " +
$"{sumOfDiscreteBranchSizes} but the model contains {modelSumDiscreteBranchSizes}." ;
return FailedCheck . Warning ( "Discrete Action Size of the model does not match. The BrainParameters expect " +
$"{sumOfDiscreteBranchSizes} but the model contains {modelSumDiscreteBranchSizes}."
) ;
}
return null ;
}
/// </param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
static string CheckContinuousActionOutputShape (
static FailedCheck CheckContinuousActionOutputShape (
BrainParameters brainParameters , ActuatorComponent [ ] actuatorComponents , int modelContinuousActionSize )
{
var numContinuousActions = brainParameters . ActionSpec . NumContinuousActions ;
if ( modelContinuousActionSize ! = numContinuousActions )
{
return "Continuous Action Size of the model does not match. The BrainParameters and ActuatorComponents expect " +
$"{numContinuousActions} but the model contains {modelContinuousActionSize}." ;
return FailedCheck . Warning (
"Continuous Action Size of the model does not match. The BrainParameters and ActuatorComponents expect " +
$"{numContinuousActions} but the model contains {modelContinuousActionSize}."
) ;
}
return null ;
}