浏览代码

adressing comments

/MLA-1734-demo-provider
vincentpierre 4 年前
当前提交
dc84fe48
共有 7 个文件被更改,包括 38 次插入88 次删除
  1. 8
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  2. 94
      com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs
  3. 16
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  4. 2
      com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs
  5. 2
      com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs
  6. 2
      com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs
  7. 2
      com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs

8
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
{
var sensor = sensorComponents[sensorIndex];
if (!sensor.IsVisual())
if (sensor.GetObservationShape().Length != 3)
{
continue;
}

for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
{
var sensorComponent = sensorComponents[sensorIndex];
if (sensorComponent.IsVisual())
if (sensorComponent.GetObservationShape().Length == 3)
{
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =

var totalVectorSensorSize = 0;
foreach (var sensorComp in sensorComponents)
{
if (sensorComp.IsVector())
if (sensorComp.GetObservationShape().Length == 1)
{
totalVectorSensorSize += sensorComp.GetObservationShape()[0];
}

var sensorSizes = "";
foreach (var sensorComp in sensorComponents)
{
if (sensorComp.IsVector())
if (sensorComp.GetObservationShape().Length == 1)
{
var vecSize = sensorComp.GetObservationShape()[0];
if (sensorSizes.Length == 0)

94
com.unity.ml-agents/Runtime/Inference/GeneratorImpl.cs


}
/// <summary>
/// Generates the Tensor corresponding to the VectorObservation input : Will be a two
/// dimensional float array of dimension [batchSize x vectorObservationSize].
/// It will use the Vector Observation data contained in the agentInfo to fill the data
/// of the tensor.
/// </summary>
internal class VectorObservationGenerator : TensorGenerator.IGenerator
{
readonly ITensorAllocator m_Allocator;
List<int> m_SensorIndices = new List<int>();
ObservationWriter m_ObservationWriter = new ObservationWriter();
public VectorObservationGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
public void AddSensorIndex(int sensorIndex)
{
m_SensorIndices.Add(sensorIndex);
}
public void Generate(TensorProxy tensorProxy, int batchSize, IEnumerable<AgentInfoSensorsPair> infos)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
var vecObsSizeT = tensorProxy.shape[tensorProxy.shape.Length - 1];
var agentIndex = 0;
foreach (var info in infos)
{
if (info.agentInfo.done)
{
// If the agent is done, we might have a stale reference to the sensors
// e.g. a dependent object might have been disposed.
// To avoid this, just fill observation with zeroes instead of calling sensor.Write.
TensorUtils.FillTensorBatch(tensorProxy, agentIndex, 0.0f);
}
else
{
var tensorOffset = 0;
// Write each sensor consecutively to the tensor
foreach (var sensorIndex in m_SensorIndices)
{
var sensor = info.sensors[sensorIndex];
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, tensorOffset);
var numWritten = sensor.Write(m_ObservationWriter);
tensorOffset += numWritten;
}
Debug.AssertFormat(
tensorOffset == vecObsSizeT,
"mismatch between vector observation size ({0}) and number of observations written ({1})",
vecObsSizeT, tensorOffset
);
}
agentIndex++;
}
}
}
/// <summary>
/// Generates the Tensor corresponding to the Recurrent input : Will be a two
/// dimensional float array of dimension [batchSize x memorySize].
/// It will use the Memory data contained in the agentInfo to fill the data

}
/// <summary>
/// Generates the Tensor corresponding to the Visual Observation input : Will be a 4
/// dimensional float array of dimension [batchSize x width x height x numChannels].
/// It will use the Texture input data contained in the agentInfo to fill the data
/// Generates the Tensor corresponding to the Observation input : Will be a multi
/// dimensional float array.
/// It will use the Observation data contained in the sensors to fill the data
internal class NonVectorObservationInputGenerator : TensorGenerator.IGenerator
internal class ObservationGenerator : TensorGenerator.IGenerator
readonly int m_SensorIndex;
List<int> m_SensorIndices = new List<int>();
public NonVectorObservationInputGenerator(
int sensorIndex, ITensorAllocator allocator)
public ObservationGenerator(ITensorAllocator allocator)
m_SensorIndex = sensorIndex;
public void AddSensorIndex(int sensorIndex)
{
m_SensorIndices.Add(sensorIndex);
}
foreach (var infoSensorPair in infos)
foreach (var info in infos)
var sensor = infoSensorPair.sensors[m_SensorIndex];
if (infoSensorPair.agentInfo.done)
if (info.agentInfo.done)
{
// If the agent is done, we might have a stale reference to the sensors
// e.g. a dependent object might have been disposed.

else
{
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, 0);
sensor.Write(m_ObservationWriter);
var tensorOffset = 0;
// Write each sensor consecutively to the tensor
foreach (var sensorIndex in m_SensorIndices)
{
var sensor = info.sensors[sensorIndex];
m_ObservationWriter.SetTarget(tensorProxy, agentIndex, tensorOffset);
var numWritten = sensor.Write(m_ObservationWriter);
tensorOffset += numWritten;
}
}
agentIndex++;
}

16
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
{
// Loop through the sensors on a representative agent.
// For vector observations, add the index to the (single) VectorObservationGenerator
// For vector observations, add the index to the (single) ObservationGenerator
VectorObservationGenerator vecObsGen = null;
ObservationGenerator vecObsGen = null;
for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
{
var sensor = sensors[sensorIndex];

case 1:
if (vecObsGen == null)
{
vecObsGen = new VectorObservationGenerator(allocator);
vecObsGen = new ObservationGenerator(allocator);
m_Dict[TensorNames.ObservationPlaceholderPrefix + sensorIndex] =
new NonVectorObservationInputGenerator(sensorIndex, allocator);
var gen = new ObservationGenerator(allocator);
gen.AddSensorIndex(sensorIndex);
m_Dict[TensorNames.ObservationPlaceholderPrefix + sensorIndex] = gen;
m_Dict[TensorNames.VisualObservationPlaceholderPrefix + visIndex] =
new NonVectorObservationInputGenerator(sensorIndex, allocator);
var visgen = new ObservationGenerator(allocator);
visgen.AddSensorIndex(sensorIndex);
m_Dict[TensorNames.VisualObservationPlaceholderPrefix + visIndex] = visgen;
visIndex++;
break;
default:

2
com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs


using UnityEngine;
using System;
namespace Unity.MLAgents.Sensors
{

/// Whether the observation is visual or not.
/// </summary>
/// <returns>True if the observation is visual, false otherwise.</returns>
[Obsolete("IsVisual is deprecated, please use GetObservationShape() instead.")]
public virtual bool IsVisual()
{
var shape = GetObservationShape();

2
com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs


const int batchSize = 4;
var agentInfos = GetFakeAgents(ObservableAttributeOptions.ExamineAll);
var alloc = new TensorCachingAllocator();
var generator = new VectorObservationGenerator(alloc);
var generator = new ObservationGenerator(alloc);
generator.AddSensorIndex(0); // ObservableAttribute (size 1)
generator.AddSensorIndex(1); // TestSensor (size 0)
generator.AddSensorIndex(2); // TestSensor (size 0)

2
com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorComponentTest.cs


var expectedShape = new[] { height, width, grayscale ? 1 : 3 };
Assert.AreEqual(expectedShape, cameraComponent.GetObservationShape());
Assert.IsTrue(cameraComponent.IsVisual());
Assert.IsFalse(cameraComponent.IsVector());
var sensor = cameraComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());

2
com.unity.ml-agents/Tests/Editor/Sensor/RenderTextureSensorComponentTests.cs


var expectedShape = new[] { height, width, grayscale ? 1 : 3 };
Assert.AreEqual(expectedShape, renderTexComponent.GetObservationShape());
Assert.IsTrue(renderTexComponent.IsVisual());
Assert.IsFalse(renderTexComponent.IsVector());
var sensor = renderTexComponent.CreateSensor();
Assert.AreEqual(expectedShape, sensor.GetObservationShape());

正在加载...
取消
保存