浏览代码

use TensorShape for index calc (#3171)

* use tensorshape for index calc

* docstring

* dont need shape anymore
/asymm-envs
GitHub 5 年前
当前提交
8cf94e1b
共有 3 个文件被更改,包括 29 次插入28 次删除
  1. 12
      UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs
  2. 5
      UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs
  3. 40
      UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs

12
UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs


valueType = TensorProxy.TensorType.FloatingPoint,
data = new Tensor(2, 3)
};
var shape = new[] { 3 };
writer.SetTarget(t, shape, 0, 0);
writer.SetTarget(t, 0, 0);
writer.SetTarget(t, shape, 1, 1);
writer.SetTarget(t, 1, 1);
writer[0] = 2f;
writer[1] = 3f;
// [0, 0] shouldn't change

data = new Tensor(2, 3)
};
writer.SetTarget(t, shape, 1, 1);
writer.SetTarget(t, 1, 1);
writer.AddRange(new [] {-1f, -2f});
Assert.AreEqual(0f, t.data[0, 0]);
Assert.AreEqual(0f, t.data[0, 1]);

var shape = new[] { 2, 2, 3 };
writer.SetTarget(t, shape, 0, 0);
writer.SetTarget(t, 0, 0);
writer.SetTarget(t, shape, 0, 1);
writer.SetTarget(t, 0, 1);
writer[1, 0, 0] = 2f;
Assert.AreEqual(2f, t.data[0, 1, 0, 1]);
}

5
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs


foreach (var sensorIndex in m_SensorIndices)
{
var sensor = agent.sensors[sensorIndex];
var shape = sensor.GetObservationShape();
m_WriteAdapter.SetTarget(tensorProxy, shape, agentIndex, tensorOffset);
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, tensorOffset);
var numWritten = sensor.Write(m_WriteAdapter);
tensorOffset += numWritten;
}

foreach (var agent in agents)
{
var sensor = agent.sensors[m_SensorIndex];
m_WriteAdapter.SetTarget(tensorProxy, sensor.GetObservationShape(), agentIndex, 0);
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, 0);
sensor.Write(m_WriteAdapter);
agentIndex++;
}

40
UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs


using System;
using System.Collections.Generic;
using Barracuda;
using MLAgents.InferenceBrain;
namespace MLAgents.Sensor

TensorProxy m_Proxy;
int m_Batch;
int[] m_Shape;
TensorShape m_TensorShape;
/// <summary>
/// Set the adapter to write to an IList at the given channelOffset.

m_Offset = offset;
m_Proxy = null;
m_Batch = 0;
m_Shape = shape;
if (shape.Length == 1)
{
m_TensorShape = new TensorShape(m_Batch, shape[0]);
}
else
{
m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]);
}
}
/// <summary>

/// <param name="shape">Shape of the observations to be written.</param>
public void SetTarget(TensorProxy tensorProxy, int[] shape, int batchIndex, int channelOffset)
public void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
m_Shape = shape;
m_TensorShape = m_Proxy.data.shape;
}
/// <summary>

{
set
{
// TODO check shape is 1D?
if (m_Data != null)
{
m_Data[index + m_Offset] = value;

{
if (m_Data != null)
{
var height = m_Shape[0];
var width = m_Shape[1];
var channels = m_Shape[2];
if (h < 0 || h >= height)
if (h < 0 || h >= m_TensorShape.height)
throw new IndexOutOfRangeException($"height value {h} must be in range [0, {height-1}]");
throw new IndexOutOfRangeException($"height value {h} must be in range [0, {m_TensorShape.height-1}]");
if (w < 0 || w >= width)
if (w < 0 || w >= m_TensorShape.width)
throw new IndexOutOfRangeException($"width value {w} must be in range [0, {width-1}]");
throw new IndexOutOfRangeException($"width value {w} must be in range [0, {m_TensorShape.width-1}]");
if (ch < 0 || ch >= channels)
if (ch < 0 || ch >= m_TensorShape.channels)
throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {channels-1}]");
throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {m_TensorShape.channels-1}]");
// Math copied from TensorShape.Index(). Note that m_Batch should always be 0
var index = m_Batch * height * width * channels + h * width * channels + w * channels + ch;
m_Data[index + m_Offset] = value;
var index = m_TensorShape.Index(m_Batch, h, w, ch + m_Offset);
m_Data[index] = value;
}
else
{

正在加载...
取消
保存