浏览代码

ObsType internal too, NumDimensions to Rank

/v2-staging-rebase
Chris Elion 4 年前
当前提交
404734d9
共有 7 个文件被更改,包括 23 次插入17 次删除
  1. 5
      Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta
  2. 6
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  3. 2
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  4. 2
      com.unity.ml-agents/Runtime/Sensors/ISensor.cs
  5. 15
      com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs
  6. 4
      com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
  7. 6
      com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs

5
Project/Assets/ML-Agents/Examples/Soccer/TFModels/SoccerTwos.onnx.meta


fileFormatVersion: 2
guid: 8cd4584c2f2cb4c5fb51675d364e10ec
ScriptedImporter:
internalIDToNameTable: []
fileIDToRecycleName:
11400000: main obj
11400002: model data
serializedVersion: 2
userData:
assetBundleName:
assetBundleVariant:

6
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
{
var sens = sensors[sensorIndex];
if (sens.GetObservationSpec().NumDimensions == 3)
if (sens.GetObservationSpec().Rank == 3)
if (sens.GetObservationSpec().NumDimensions == 2)
if (sens.GetObservationSpec().Rank == 2)
if (sens.GetObservationSpec().NumDimensions == 1)
if (sens.GetObservationSpec().Rank == 1)
{
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankOneObsShape(tensor, sens);

2
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
{
var sensor = sensors[sensorIndex];
var rank = sensor.GetObservationSpec().NumDimensions;
var rank = sensor.GetObservationSpec().Rank;
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)

2
com.unity.ml-agents/Runtime/Sensors/ISensor.cs


{
var obsSpec = sensor.GetObservationSpec();
var count = 1;
for (var i = 0; i < obsSpec.NumDimensions; i++)
for (var i = 0; i < obsSpec.Rank; i++)
{
count *= obsSpec.Shape[i];
}

15
com.unity.ml-agents/Runtime/Sensors/ObservationSpec.cs


/// </summary>
public struct ObservationSpec
{
internal InplaceArray<int> m_Shape;
internal readonly InplaceArray<int> m_Shape;
/// <summary>
/// The size of the observations that will be generated.

get => m_Shape;
}
internal InplaceArray<DimensionProperty> m_DimensionProperties;
internal readonly InplaceArray<DimensionProperty> m_DimensionProperties;
/// <summary>
/// The properties of each dimensions of the observation.

get => m_DimensionProperties;
}
internal ObservationType m_ObservationType;
public ObservationType ObservationType;
public ObservationType ObservationType
{
get => m_ObservationType;
}
public int NumDimensions
public int Rank
{
get { return Shape.Length; }
}

}
m_Shape = shape;
m_DimensionProperties = dimensionProperties;
ObservationType = observationType;
m_ObservationType = observationType;
}
}
}

4
com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs


m_CompressionMapping = ConstructStackedCompressedChannelMapping(wrapped);
}
if (m_WrappedSpec.NumDimensions != 1)
if (m_WrappedSpec.Rank != 1)
{
var wrappedShape = m_WrappedSpec.Shape;
m_tensorShape = new TensorShape(0, wrappedShape[0], wrappedShape[1], wrappedShape[2]);

// Now write the saved observations (oldest first)
var numWritten = 0;
if (m_WrappedSpec.NumDimensions == 1)
if (m_WrappedSpec.Rank == 1)
{
for (var i = 0; i < m_NumStackedObservations; i++)
{

6
com.unity.ml-agents/Tests/Editor/ObservationSpecTests.cs


public void TestVectorObsSpec()
{
var obsSpec = ObservationSpec.Vector(5);
Assert.AreEqual(1, obsSpec.NumDimensions);
Assert.AreEqual(1, obsSpec.Rank);
var shape = obsSpec.Shape;
Assert.AreEqual(1, shape.Length);

public void TestVariableLengthObsSpec()
{
var obsSpec = ObservationSpec.VariableLength(5, 6);
Assert.AreEqual(2, obsSpec.NumDimensions);
Assert.AreEqual(2, obsSpec.Rank);
var shape = obsSpec.Shape;
Assert.AreEqual(2, shape.Length);

public void TestVisualObsSpec()
{
var obsSpec = ObservationSpec.Visual(5, 6, 7);
Assert.AreEqual(3, obsSpec.NumDimensions);
Assert.AreEqual(3, obsSpec.Rank);
var shape = obsSpec.Shape;
Assert.AreEqual(3, shape.Length);

正在加载...
取消
保存