浏览代码

test tensor resize

/r2v-yamato-linux
Ruo-Ping Dong 4 年前
当前提交
c8b9a1ea
共有 1 个文件被更改,包括 52 次插入0 次删除
  1. 52
      com.unity.ml-agents/Tests/Editor/TensorUtilsTest.cs

52
com.unity.ml-agents/Tests/Editor/TensorUtilsTest.cs


{
public class TensorUtilsTest
{
[TestCase(4, TestName = "TestResizeTensor_4D")]
[TestCase(8, TestName = "TestResizeTensor_8D")]
public void TestResizeTensor(int dimension)
{
var alloc = new TensorCachingAllocator();
var height = 64;
var width = 84;
var channels = 3;
// Set shape to {1, ..., height, width, channels}
// For 8D, the ... are all 1's
var shape = new long[dimension];
for (var i = 0; i < dimension; i++)
{
shape[i] = 1;
}
shape[dimension - 3] = height;
shape[dimension - 2] = width;
shape[dimension - 1] = channels;
var intShape = new int[dimension];
for (var i = 0; i < dimension; i++)
{
intShape[i] = (int)shape[i];
}
var tensorProxy = new TensorProxy
{
valueType = TensorProxy.TensorType.Integer,
data = new Tensor(intShape),
shape = shape,
};
// These should be invariant after the resize.
Assert.AreEqual(height, tensorProxy.data.shape.height);
Assert.AreEqual(width, tensorProxy.data.shape.width);
Assert.AreEqual(channels, tensorProxy.data.shape.channels);
TensorUtils.ResizeTensor(tensorProxy, 42, alloc);
Assert.AreEqual(height, tensorProxy.shape[dimension - 3]);
Assert.AreEqual(width, tensorProxy.shape[dimension - 2]);
Assert.AreEqual(channels, tensorProxy.shape[dimension - 1]);
Assert.AreEqual(height, tensorProxy.data.shape.height);
Assert.AreEqual(width, tensorProxy.data.shape.width);
Assert.AreEqual(channels, tensorProxy.data.shape.channels);
alloc.Dispose();
}
[Test]
public void RandomNormalTestTensorInt()
{

正在加载...
取消
保存