浏览代码

handle visual observations with onnx

/r2v-yamato-linux
Ruo-Ping Dong 4 年前
当前提交
e75cd7b3
共有 2 个文件被更改,包括 22 次插入7 次删除
  1. 6
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  2. 23
      com.unity.ml-agents/Runtime/Inference/TensorProxy.cs

6
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


var heightBp = shape[0];
var widthBp = shape[1];
var pixelBp = shape[2];
var heightT = tensorProxy.shape[1];
var widthT = tensorProxy.shape[2];
var pixelT = tensorProxy.shape[3];
var heightT = tensorProxy.Height;
var widthT = tensorProxy.Width;
var pixelT = tensorProxy.Channels;
if ((widthBp != widthT) || (heightBp != heightT) || (pixelBp != pixelT))
{
return $"The visual Observation of the model does not match. " +

23
com.unity.ml-agents/Runtime/Inference/TensorProxy.cs


public Type DataType => k_TypeMap[valueType];
public long[] shape;
public Tensor data;
public long Height
{
get { return shape.Length == 4 ? shape[1] : shape[5]; }
}
public long Width
{
get { return shape.Length == 4 ? shape[2] : shape[6]; }
}
public long Channels
{
get { return shape.Length == 4 ? shape[3] : shape[7]; }
}
}
internal static class TensorUtils

tensor.data?.Dispose();
tensor.shape[0] = batch;
if (tensor.shape.Length == 4)
if (tensor.shape.Length == 4 || tensor.shape.Length == 8)
(int)tensor.shape[1],
(int)tensor.shape[2],
(int)tensor.shape[3]));
(int)tensor.Height,
(int)tensor.Width,
(int)tensor.Channels));
}
else
{

正在加载...
取消
保存