浏览代码

Preliminary work for inference with attention

/MLA-1734-demo-provider
vincentpierre 4 年前
当前提交
3bc43ef2
共有 4 个文件被更改,包括 60 次插入21 次删除
  1. 41
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  2. 4
      com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs
  3. 35
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  4. 1
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs

41
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


}
/// <summary>
/// Checks that the shape of the rank 2 observation input placeholder is the same as the corresponding sensor.
/// </summary>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensorComponent">The sensor that produces the visual observation.</param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckRankTwoObsShape(
TensorProxy tensorProxy, SensorComponent sensorComponent)
{
var shape = sensorComponent.GetObservationShape();
var dim1Bp = shape[0];
var dim2Bp = shape[1];
var dim1T = tensorProxy.Channels;
var dim2T = tensorProxy.Width;
if ((dim1Bp != dim1T) || (dim2Bp != dim2T))
{
return $"An Observation of the model does not match. " +
$"Received TensorProxy of shape [?x{dim1Bp}x{dim2Bp}] but " +
$"was expecting [?x{dim1T}x{dim2T}].";
}
return null;
}
/// <summary>
/// Generates failed checks that correspond to inputs shapes incompatibilities between
/// the model and the BrainParameters.
/// </summary>

for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
{
var sensorComponent = sensorComponents[sensorIndex];
if (!sensorComponent.IsVisual())
if (sensorComponent.IsVisual())
{
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent);
visObsIndex++;
}
if (sensorComponent.GetObservationShape().Length == 2)
continue;
tensorTester[TensorNames.ObservationPlaceholderPrefix + sensorIndex] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sensorComponent);
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent);
visObsIndex++;
}
// If the model expects an input but it is not in this list

4
com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs


/// It will use the Texture input data contained in the agentInfo to fill the data
/// of the tensor.
/// </summary>
internal class VisualObservationInputGenerator : TensorGenerator.IGenerator
internal class NonVectorObservationInputGenerator : TensorGenerator.IGenerator
public VisualObservationInputGenerator(
public NonVectorObservationInputGenerator(
int sensorIndex, ITensorAllocator allocator)
{
m_SensorIndex = sensorIndex;

35
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


{
// Loop through the sensors on a representative agent.
// For vector observations, add the index to the (single) VectorObservationGenerator
// For visual observations, make a VisualObservationInputGenerator
// For visual observations, make a NonVectorObservationInputGenerator
var visIndex = 0;
VectorObservationGenerator vecObsGen = null;
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)

// TODO generalize - we currently only have vector or visual, but can't handle "2D" observations
var isVectorSensor = (shape.Length == 1);
if (isVectorSensor)
var rank = shape.Length;
switch (rank)
if (vecObsGen == null)
{
vecObsGen = new VectorObservationGenerator(allocator);
}
case 1:
if (vecObsGen == null)
{
vecObsGen = new VectorObservationGenerator(allocator);
}
vecObsGen.AddSensorIndex(sensorIndex);
break;
case 2:
m_Dict[TensorNames.ObservationPlaceholderPrefix + sensorIndex] =
new NonVectorObservationInputGenerator(sensorIndex, allocator);
break;
case 3:
m_Dict[TensorNames.VisualObservationPlaceholderPrefix + visIndex] =
new NonVectorObservationInputGenerator(sensorIndex, allocator);
visIndex++;
break;
default:
break;
vecObsGen.AddSensorIndex(sensorIndex);
}
else
{
m_Dict[TensorNames.VisualObservationPlaceholderPrefix + visIndex] =
new VisualObservationInputGenerator(sensorIndex, allocator);
visIndex++;
}
}

1
com.unity.ml-agents/Runtime/Inference/TensorNames.cs


public const string recurrentInPlaceholderH = "recurrent_in_h";
public const string recurrentInPlaceholderC = "recurrent_in_c";
public const string VisualObservationPlaceholderPrefix = "visual_observation_";
public const string ObservationPlaceholderPrefix = "obs_";
public const string PreviousActionPlaceholder = "prev_action";
public const string ActionMaskPlaceholder = "action_masks";
public const string RandomNormalEpsilonPlaceholder = "epsilon";

正在加载...
取消
保存