|
|
|
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Checks that the shape of the rank 2 observation input placeholder is the same as the corresponding sensor.
|
|
|
|
/// </summary>
|
|
|
|
/// <param name="tensorProxy">The tensor that is expected by the model</param>
|
|
|
|
/// <param name="sensorComponent">The sensor that produces the visual observation.</param>
|
|
|
|
/// <returns>
|
|
|
|
/// If the Check failed, returns a string containing information about why the
|
|
|
|
/// check failed. If the check passed, returns null.
|
|
|
|
/// </returns>
|
|
|
|
static string CheckRankTwoObsShape( |
|
|
|
TensorProxy tensorProxy, SensorComponent sensorComponent) |
|
|
|
{ |
|
|
|
var shape = sensorComponent.GetObservationShape(); |
|
|
|
var dim1Bp = shape[0]; |
|
|
|
var dim2Bp = shape[1]; |
|
|
|
var dim1T = tensorProxy.Channels; |
|
|
|
var dim2T = tensorProxy.Width; |
|
|
|
if ((dim1Bp != dim1T) || (dim2Bp != dim2T)) |
|
|
|
{ |
|
|
|
return $"An Observation of the model does not match. " + |
|
|
|
$"Received TensorProxy of shape [?x{dim1Bp}x{dim2Bp}] but " + |
|
|
|
$"was expecting [?x{dim1T}x{dim2T}]."; |
|
|
|
} |
|
|
|
return null; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// Generates failed checks that correspond to inputs shapes incompatibilities between
|
|
|
|
/// the model and the BrainParameters.
|
|
|
|
/// </summary>
|
|
|
|
|
|
|
for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++) |
|
|
|
{ |
|
|
|
var sensorComponent = sensorComponents[sensorIndex]; |
|
|
|
if (!sensorComponent.IsVisual()) |
|
|
|
if (sensorComponent.IsVisual()) |
|
|
|
{ |
|
|
|
|
|
|
|
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] = |
|
|
|
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent); |
|
|
|
visObsIndex++; |
|
|
|
} |
|
|
|
if (sensorComponent.GetObservationShape().Length == 2) |
|
|
|
continue; |
|
|
|
tensorTester[TensorNames.ObservationPlaceholderPrefix + sensorIndex] = |
|
|
|
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sensorComponent); |
|
|
|
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] = |
|
|
|
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent); |
|
|
|
visObsIndex++; |
|
|
|
} |
|
|
|
|
|
|
|
// If the model expects an input but it is not in this list
|
|
|
|