浏览代码

Fix visual observations with onnx (#4475) (#4479)

* handle visual observations with onnx

* test tensor resize
/release_7_branch
GitHub 4 年前
当前提交
f915b88f
共有 3 个文件被更改,包括 74 次插入7 次删除
  1. 6
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  2. 23
      com.unity.ml-agents/Runtime/Inference/TensorProxy.cs
  3. 52
      com.unity.ml-agents/Tests/Editor/TensorUtilsTest.cs

6
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


var heightBp = shape[0];
var widthBp = shape[1];
var pixelBp = shape[2];
var heightT = tensorProxy.shape[1];
var widthT = tensorProxy.shape[2];
var pixelT = tensorProxy.shape[3];
var heightT = tensorProxy.Height;
var widthT = tensorProxy.Width;
var pixelT = tensorProxy.Channels;
if ((widthBp != widthT) || (heightBp != heightT) || (pixelBp != pixelT))
{
return $"The visual Observation of the model does not match. " +

23
com.unity.ml-agents/Runtime/Inference/TensorProxy.cs


public Type DataType => k_TypeMap[valueType];
public long[] shape;
public Tensor data;
public long Height
{
get { return shape.Length == 4 ? shape[1] : shape[5]; }
}
public long Width
{
get { return shape.Length == 4 ? shape[2] : shape[6]; }
}
public long Channels
{
get { return shape.Length == 4 ? shape[3] : shape[7]; }
}
}
internal static class TensorUtils

tensor.data?.Dispose();
tensor.shape[0] = batch;
if (tensor.shape.Length == 4)
if (tensor.shape.Length == 4 || tensor.shape.Length == 8)
(int)tensor.shape[1],
(int)tensor.shape[2],
(int)tensor.shape[3]));
(int)tensor.Height,
(int)tensor.Width,
(int)tensor.Channels));
}
else
{

52
com.unity.ml-agents/Tests/Editor/TensorUtilsTest.cs


{
public class TensorUtilsTest
{
[TestCase(4, TestName = "TestResizeTensor_4D")]
[TestCase(8, TestName = "TestResizeTensor_8D")]
public void TestResizeTensor(int dimension)
{
var alloc = new TensorCachingAllocator();
var height = 64;
var width = 84;
var channels = 3;
// Set shape to {1, ..., height, width, channels}
// For 8D, the ... are all 1's
var shape = new long[dimension];
for (var i = 0; i < dimension; i++)
{
shape[i] = 1;
}
shape[dimension - 3] = height;
shape[dimension - 2] = width;
shape[dimension - 1] = channels;
var intShape = new int[dimension];
for (var i = 0; i < dimension; i++)
{
intShape[i] = (int)shape[i];
}
var tensorProxy = new TensorProxy
{
valueType = TensorProxy.TensorType.Integer,
data = new Tensor(intShape),
shape = shape,
};
// These should be invariant after the resize.
Assert.AreEqual(height, tensorProxy.data.shape.height);
Assert.AreEqual(width, tensorProxy.data.shape.width);
Assert.AreEqual(channels, tensorProxy.data.shape.channels);
TensorUtils.ResizeTensor(tensorProxy, 42, alloc);
Assert.AreEqual(height, tensorProxy.shape[dimension - 3]);
Assert.AreEqual(width, tensorProxy.shape[dimension - 2]);
Assert.AreEqual(channels, tensorProxy.shape[dimension - 1]);
Assert.AreEqual(height, tensorProxy.data.shape.height);
Assert.AreEqual(width, tensorProxy.data.shape.width);
Assert.AreEqual(channels, tensorProxy.data.shape.channels);
alloc.Dispose();
}
[Test]
public void RandomNormalTestTensorInt()
{

正在加载...
取消
保存