浏览代码

Merge pull request #4843 from Unity-Technologies/develop-inference-rank-2

Preliminary work for inference with attention
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
eb630ca3
共有 8 个文件被更改,包括 115 次插入115 次删除
  1. 70
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  2. 94
      com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs
  3. 56
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  4. 1
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs
  5. 3
      com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs
  6. 2
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs
  7. 2
      com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs
  8. 2
      com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs

70
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
{
var sensor = sensorComponents[sensorIndex];
if (!sensor.IsVisual())
if (sensor.GetObservationShape().Length == 3)
continue;
if (!tensorsNames.Contains(
TensorNames.VisualObservationPlaceholderPrefix + visObsIndex))
{
failedModelChecks.Add(
"The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name}).");
}
visObsIndex++;
if (!tensorsNames.Contains(
TensorNames.VisualObservationPlaceholderPrefix + visObsIndex))
if (sensor.GetObservationShape().Length == 2)
failedModelChecks.Add(
"The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name}).");
if (!tensorsNames.Contains(
TensorNames.ObservationPlaceholderPrefix + sensorIndex))
{
failedModelChecks.Add(
"The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name}).");
}
visObsIndex++;
}
var expectedVisualObs = model.GetNumVisualInputs();

}
/// <summary>
/// Checks that the shape of the rank 2 observation input placeholder is the same as the corresponding sensor.
/// </summary>
/// <param name="tensorProxy">The tensor that is expected by the model</param>
/// <param name="sensorComponent">The sensor that produces the visual observation.</param>
/// <returns>
/// If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.
/// </returns>
static string CheckRankTwoObsShape(
TensorProxy tensorProxy, SensorComponent sensorComponent)
{
var shape = sensorComponent.GetObservationShape();
var dim1Bp = shape[0];
var dim2Bp = shape[1];
var dim1T = tensorProxy.Channels;
var dim2T = tensorProxy.Width;
if ((dim1Bp != dim1T) || (dim2Bp != dim2T))
{
return $"An Observation of the model does not match. " +
$"Received TensorProxy of shape [?x{dim1Bp}x{dim2Bp}] but " +
$"was expecting [?x{dim1T}x{dim2T}].";
}
return null;
}
/// <summary>
/// Generates failed checks that correspond to inputs shapes incompatibilities between
/// the model and the BrainParameters.
/// </summary>

for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
{
var sensorComponent = sensorComponents[sensorIndex];
if (!sensorComponent.IsVisual())
if (sensorComponent.GetObservationShape().Length == 3)
{
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent);
visObsIndex++;
}
if (sensorComponent.GetObservationShape().Length == 2)
continue;
tensorTester[TensorNames.ObservationPlaceholderPrefix + sensorIndex] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sensorComponent);
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
(bp, tensor, scs, i) => CheckVisualObsShape(tensor, sensorComponent);
visObsIndex++;
}
// If the model expects an input but it is not in this list

var totalVectorSensorSize = 0;
foreach (var sensorComp in sensorComponents)
{
if (sensorComp.IsVector())
if (sensorComp.GetObservationShape().Length == 1)
{
totalVectorSensorSize += sensorComp.GetObservationShape()[0];
}

var sensorSizes = "";
foreach (var sensorComp in sensorComponents)
{
if (sensorComp.IsVector())
if (sensorComp.GetObservationShape().Length == 1)
{
var vecSize = sensorComp.GetObservationShape()[0];
if (sensorSizes.Length == 0)

94
com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs


}
/// <summary>
/// Generates the Tensor corresponding to the VectorObservation input : Will be a two
/// dimensional float array of dimension [batchSize x vectorObservationSize].
/// It will use the Vector Observation data contained in the agentInfo to fill the data
/// of the tensor.
/// </summary>
internal class VectorObservationGenerator : TensorGenerator.IGenerator
{
readonly ITensorAllocator m_Allocator;
List<int> m_SensorIndices = new List<int>();
ObservationWriter m_ObservationWriter = new ObservationWriter();
public VectorObservationGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
public void AddSensorIndex(int sensorIndex)
{
m_SensorIndices.Add(sensorIndex);
}
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
var vecObsSizeT = tensorProxy.shape[tensorProxy.shape.Length - 1];
var agentIndex = 0;
foreach (var info in infos)
{
if (info.agentInfo.done)
{
// If the agent is done, we might have a stale reference to the sensors
// e.g. a dependent object might have been disposed.
// To avoid this, just fill observation with zeroes instead of calling sensor.Write.
TensorUtils.FillTensorBatch(tensorProxy, agentIndex, 0.0f);
}
else
{
var tensorOffset = 0;
// Write each sensor consecutively to the tensor
foreach (var sensorIndex in m_SensorIndices)
{
var sensor = info.sensors[sensorIndex];
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, tensorOffset);
var numWritten = sensor.Write(m_ObservationWriter);
tensorOffset += numWritten;
}
Debug.AssertFormat(
tensorOffset == vecObsSizeT,
"mismatch between vector observation size ({0}) and number of observations written ({1})",
vecObsSizeT, tensorOffset
);
}
agentIndex++;
}
}
}
/// <summary>
/// Generates the Tensor corresponding to the Recurrent input : Will be a two
/// dimensional float array of dimension [batchSize x memorySize].
/// It will use the Memory data contained in the agentInfo to fill the data

}
/// <summary>
/// Generates the Tensor corresponding to the Visual Observation input : Will be a 4
/// dimensional float array of dimension [batchSize x width x height x numChannels].
/// It will use the Texture input data contained in the agentInfo to fill the data
/// Generates the Tensor corresponding to the Observation input : Will be a multi
/// dimensional float array.
/// It will use the Observation data contained in the sensors to fill the data
internal class VisualObservationInputGenerator : TensorGenerator.IGenerator
internal class ObservationGenerator : TensorGenerator.IGenerator
readonly int m_SensorIndex;
List<int> m_SensorIndices = new List<int>();
public VisualObservationInputGenerator(
int sensorIndex, ITensorAllocator allocator)
public ObservationGenerator(ITensorAllocator allocator)
m_SensorIndex = sensorIndex;
public void AddSensorIndex(int sensorIndex)
{
m_SensorIndices.Add(sensorIndex);
}
foreach (var infoSensorPair in infos)
foreach (var info in infos)
var sensor = infoSensorPair.sensors[m_SensorIndex];
if (infoSensorPair.agentInfo.done)
if (info.agentInfo.done)
{
// If the agent is done, we might have a stale reference to the sensors
// e.g. a dependent object might have been disposed.

else
{
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, 0);
sensor.Write(m_ObservationWriter);
var tensorOffset = 0;
// Write each sensor consecutively to the tensor
foreach (var sensorIndex in m_SensorIndices)
{
var sensor = info.sensors[sensorIndex];
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, tensorOffset);
var numWritten = sensor.Write(m_ObservationWriter);
tensorOffset += numWritten;
}
}
agentIndex++;
}

56
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
{
// Loop through the sensors on a representative agent.
// For vector observations, add the index to the (single) VectorObservationGenerator
// For visual observations, make a VisualObservationInputGenerator
// All vector observations use a shared ObservationGenerator since they are concatenated.
// All other observations use a unique ObservationInputGenerator
VectorObservationGenerator vecObsGen = null;
ObservationGenerator vecObsGen = null;
// TODO generalize - we currently only have vector or visual, but can't handle "2D" observations
var isVectorSensor = (shape.Length == 1);
if (isVectorSensor)
{
if (vecObsGen == null)
{
vecObsGen = new VectorObservationGenerator(allocator);
}
vecObsGen.AddSensorIndex(sensorIndex);
}
else
var rank = shape.Length;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)
m_Dict[TensorNames.VisualObservationPlaceholderPrefix + visIndex] =
new VisualObservationInputGenerator(sensorIndex, allocator);
visIndex++;
case 1:
if (vecObsGen == null)
{
vecObsGen = new ObservationGenerator(allocator);
}
obsGen = vecObsGen;
obsGenName = TensorNames.VectorObservationPlaceholder;
break;
case 2:
// If the tensor is of rank 2, we use the index of the sensor
// to create the name
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.ObservationPlaceholderPrefix + sensorIndex;
break;
case 3:
// If the tensor is of rank 3, we use the "visual observation
// index", which only counts the rank 3 sensors
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.VisualObservationPlaceholderPrefix + visIndex;
visIndex++;
break;
default:
throw new UnityAgentsException(
$"Sensor {sensor.GetName()} have an invalid rank {rank}");
}
if (vecObsGen != null)
{
m_Dict[TensorNames.VectorObservationPlaceholder] = vecObsGen;
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
}

1
com.unity.ml-agents/Runtime/Inference/TensorNames.cs


public const string recurrentInPlaceholderH = "recurrent_in_h";
public const string recurrentInPlaceholderC = "recurrent_in_c";
public const string VisualObservationPlaceholderPrefix = "visual_observation_";
public const string ObservationPlaceholderPrefix = "obs_";
public const string PreviousActionPlaceholder = "prev_action";
public const string ActionMaskPlaceholder = "action_masks";
public const string RandomNormalEpsilonPlaceholder = "epsilon";

3
com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs


using UnityEngine;
using System;
namespace Unity.MLAgents.Sensors
{

/// Whether the observation is visual or not.
/// </summary>
/// <returns>True if the observation is visual, false otherwise.</returns>
[Obsolete("IsVisual is deprecated, please use GetObservationShape() instead.")]
public virtual bool IsVisual()
{
var shape = GetObservationShape();

/// Whether the observation is vector or not.
/// </summary>
/// <returns>True if the observation is vector, false otherwise.</returns>
[Obsolete("IsVisual is deprecated, please use GetObservationShape() instead.")]
public virtual bool IsVector()
{
var shape = GetObservationShape();

2
com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs


const int batchSize = 4;
var agentInfos = GetFakeAgents(ObservableAttributeOptions.ExamineAll);
var alloc = new TensorCachingAllocator();
var generator = new VectorObservationGenerator(alloc);
var generator = new ObservationGenerator(alloc);
generator.AddSensorIndex(0); // ObservableAttribute (size 1)
generator.AddSensorIndex(1); // TestSensor (size 0)
generator.AddSensorIndex(2); // TestSensor (size 0)

2
com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs


var expectedShape = new[] { height, width, grayscale ? 1 : 3 };
Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape());
Assert.IsTrue(cameraComponent.IsVisual());
Assert.IsFalse(cameraComponent.IsVector());
var sensor = cameraComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());

2
com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs


var expectedShape = new[] { height, width, grayscale ? 1 : 3 };
Assert.AreEqual(expectedShape, renderTexComponent.GetObservationShape());
Assert.IsTrue(renderTexComponent.IsVisual());
Assert.IsFalse(renderTexComponent.IsVector());
var sensor = renderTexComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());

正在加载...
取消
保存