浏览代码

Addressed comments

/develop-generalizationTraining-TrainerController
vincentpierre 6 年前
当前提交
5a44d3ee
共有 4 个文件被更改,包括 520 次插入和 180 次删除
  1. 2
      UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelParamLoader.cs
  2. 352
      UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TFSharpInferenceEngine.cs
  3. 224
      UnitySDK/Assets/ML-Agents/Editor/Tests/MultinomialTest.cs
  4. 122
      UnitySDK/Assets/ML-Agents/Editor/Tests/RandomNormalTest.cs

2
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelParamLoader.cs


{
_engine.ExecuteGraph(new Tensor[0], outputs);
}
catch (Exception e)
catch
{
return -1;
}

352
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TFSharpInferenceEngine.cs


namespace MLAgents.InferenceBrain
{
    /// <summary>
    /// TFSharpInferenceEngine - Inference engine utilizing the TensorFlow Sharp package to run inference
    /// on frozen TensorFlow models
    /// </summary>
    public class TFSharpInferenceEngine
    {
        // Imported frozen graph and the session used to execute it. Both are
        // created in PrepareModel and reused for every ExecuteGraph call.
        private TFGraph m_graph;
        private TFSession m_session;

        /// <summary>
        /// Imports a frozen TensorFlow model and creates the session used for inference.
        /// Must be called before ExecuteGraph or InputFeatures.
        /// </summary>
        /// <param name="model">Serialized bytes of a frozen TensorFlow graph.</param>
        public void PrepareModel(byte[] model)
        {
            Profiler.BeginSample("TFSharpInferenceComponent.PrepareModel");
            // This needs to be called only once and will raise an exception if called multiple times
            try
            {
                TensorFlowSharp.Android.NativeBinding.Init();
            }
            catch
            {
                // Deliberate best-effort: Init() throws when invoked more than
                // once (e.g. on domain reload); the earlier call already succeeded.
            }
            m_graph = new TFGraph();
            m_graph.Import(model);
            m_session = new TFSession(m_graph);
            Profiler.EndSample();
        }

        /// <summary>
        /// Runs one inference pass: feeds the given inputs, fetches the named
        /// outputs and copies the resulting values into the output tensors.
        /// </summary>
        /// <param name="inputs_it">Input tensors; scalar tensors (rank 0) are fed by value.</param>
        /// <param name="outputs_it">Output tensors; their Data arrays are (re)assigned.</param>
        /// <returns>0 on success, -1 when the TensorFlow run reported an error.</returns>
        public int ExecuteGraph(IEnumerable<Tensor> inputs_it, IEnumerable<Tensor> outputs_it)
        {
            Profiler.BeginSample("TFSharpInferenceComponent.ExecuteGraph");
            Tensor[] inputs = inputs_it.ToArray();
            Tensor[] outputs = outputs_it.ToArray();

            // TODO: Can/should we pre-allocate that?
            TFSession.Runner runner = m_session.GetRunner();

            inputs.ToList().ForEach((Tensor input) =>
            {
                if (input.Shape.Length == 0)
                {
                    // Scalars must be fed as plain values, not as arrays.
                    var data = input.Data.GetValue(0);
                    if (input.DataType == typeof(int))
                    {
                        runner.AddInput(m_graph[input.Name][0], (int)data);
                    }
                    else
                    {
                        runner.AddInput(m_graph[input.Name][0], (float)data);
                    }
                }
                else
                {
                    runner.AddInput(m_graph[input.Name][0], input.Data);
                }
            });

            // TODO: better way to pre-allocate this?
            outputs.ToList().ForEach(s => runner.Fetch(s.Name));

            TFStatus status = new TFStatus();
            Profiler.BeginSample("TFSharpInferenceComponent.ExecuteGraph.RunnerRun");
            var out_tensors = runner.Run(status);
            Profiler.EndSample();

            if (!status.Ok)
            {
                Debug.LogError(status.StatusMessage);
                return -1;
            }

            Debug.Assert(outputs.Length == out_tensors.Length);

            for (var i = 0; i < outputs.Length; ++i)
            {
                if (outputs[i].Shape.Length == 0)
                {
                    // Handle scalars: wrap the single value in a one-element array.
                    outputs[i].Data = Array.CreateInstance(outputs[i].DataType, new long[1] {1});
                    outputs[i].Data.SetValue(out_tensors[i].GetValue(), 0);
                }
                else
                {
                    outputs[i].Data = out_tensors[i].GetValue() as Array;
                }
            }

            Profiler.EndSample();
            // TODO: create error codes
            return 0;
        }

        // Raw TensorFlow C API entry points used to query operation attributes
        // that TensorFlowSharp does not expose directly.
        [DllImport("libtensorflow")]
        private static extern unsafe void TF_OperationGetAttrType(IntPtr oper, string attr_name,
            TFDataType* value, IntPtr status);

        [DllImport("libtensorflow")]
        private static extern unsafe void TF_OperationGetAttrShape(IntPtr oper, string attr_name, long[] value,
            int num_dims, IntPtr status);

        [DllImport("libtensorflow")]
        private static extern unsafe void TF_OperationGetAttrShapeList(
            IntPtr oper, string attr_name, long*[] dims, int[] num_dims, long num_shapes, long[] storage,
            long storage_size, IntPtr status);

        /// <summary>
        /// Builds a Tensor descriptor (name, shape, value type, no data) for a
        /// graph operation by querying its "shape" and "dtype" attributes.
        /// </summary>
        /// <param name="op">Graph operation (typically a Placeholder) to inspect.</param>
        /// <returns>Tensor metadata describing the operation; Data is left null.</returns>
        private Tensor GetOpMetadata(TFOperation op)
        {
            TFStatus status = new TFStatus();

            // Query the shape
            long[] shape = null;
            var shape_attr = op.GetAttributeMetadata("shape", status);
            if (!status.Ok || shape_attr.TotalSize <= 0)
            {
                Debug.LogWarning("Operation " + op.Name + " does not contain shape attribute or it" +
                    " doesn't contain valid shape data!");
            }
            else
            {
                if (shape_attr.IsList)
                {
                    throw new NotImplementedException("Querying lists is not implemented yet!");
                }
                else
                {
                    TFStatus s = new TFStatus();
                    long[] dims = new long[shape_attr.TotalSize];
                    TF_OperationGetAttrShape(op.Handle, "shape", dims, (int)shape_attr.TotalSize,
                        s.Handle);
                    // BUGFIX: check the status of the call just made (s), not the
                    // earlier, already-validated 'status' — the old check could never fail.
                    if (!s.Ok)
                    {
                        throw new FormatException("Could not query model for op shape (" + op.Name + ")");
                    }
                    else
                    {
                        shape = new long[dims.Length];
                        for (int i = 0; i < shape_attr.TotalSize; ++i)
                        {
                            if (dims[i] == -1)
                            {
                                // TODO: since CoreML doesn't support batching yet
                                // we have to use batchsize 1
                                shape[i] = 1;
                            }
                            else
                            {
                                shape[i] = dims[i];
                            }
                        }
                    }
                }
            }

            // Query the data type
            TFDataType type_value = new TFDataType();
            unsafe
            {
                TFStatus s = new TFStatus();
                TF_OperationGetAttrType(op.Handle, "dtype", &type_value, s.Handle);
                if (!s.Ok)
                {
                    Debug.LogWarning("Operation " + op.Name +
                        ": error retrieving dtype, assuming float!");
                    type_value = TFDataType.Float;
                }
            }

            Tensor.TensorType placeholder_type = Tensor.TensorType.FloatingPoint;
            switch (type_value)
            {
                case TFDataType.Float:
                    placeholder_type = Tensor.TensorType.FloatingPoint;
                    break;
                case TFDataType.Int32:
                    placeholder_type = Tensor.TensorType.Integer;
                    break;
                default:
                    Debug.LogWarning("Operation " + op.Name +
                        " is not a float/integer. Proceed at your own risk!");
                    break;
            }

            Tensor t = new Tensor
            {
                Data = null,
                Name = op.Name,
                Shape = shape,
                ValueType = placeholder_type
            };
            return t;
        }

        /// <summary>
        /// Enumerates the graph's Placeholder operations as Tensor descriptors,
        /// i.e. the inputs the model expects to be fed.
        /// </summary>
        /// <returns>One metadata-only Tensor per Placeholder operation.</returns>
        public IEnumerable<Tensor> InputFeatures()
        {
            List<Tensor> inputs = new List<Tensor>();
            foreach (var op in m_graph.GetEnumerator())
            {
                if (op.OpType == "Placeholder")
                {
                    inputs.Add(GetOpMetadata(op));
                }
            }
            return inputs;
        }
    }
}
#endif

224
UnitySDK/Assets/ML-Agents/Editor/Tests/MultinomialTest.cs


using System;
using NUnit.Framework;
using UnityEngine;
using MLAgents.InferenceBrain;
using MLAgents.InferenceBrain.Utils;
namespace MLAgents.Tests
{
    /// <summary>
    /// Tests for the Multinomial sampler: sampling from probabilities, from
    /// logits, batched sampling, and rejection of malformed arguments.
    /// </summary>
    public class MultinomialTest
    {
        [Test]
        public void TestEvalP()
        {
            var multinomial = new Multinomial(2018);
            var source = new Tensor
            {
                Data = new float[1, 3] {{0.1f, 0.2f, 0.7f}},
                ValueType = Tensor.TensorType.FloatingPoint
            };
            var destination = new Tensor
            {
                Data = new float[1, 3],
                ValueType = Tensor.TensorType.FloatingPoint
            };

            multinomial.Eval(source, destination);

            // Samples expected for seed 2018 with these probabilities.
            float[] expected = {2, 2, 1};
            var index = 0;
            foreach (var sample in destination.Data)
            {
                Assert.AreEqual(expected[index], sample);
                index++;
            }
        }

        [Test]
        public void TestEvalLogits()
        {
            var multinomial = new Multinomial(2018);
            // Same distribution expressed as shifted log-probabilities.
            var source = new Tensor
            {
                Data = new float[1, 3] {{Mathf.Log(0.1f) - 50, Mathf.Log(0.2f) - 50, Mathf.Log(0.7f) - 50}},
                ValueType = Tensor.TensorType.FloatingPoint
            };
            var destination = new Tensor
            {
                Data = new float[1, 3],
                ValueType = Tensor.TensorType.FloatingPoint
            };

            multinomial.Eval(source, destination);

            float[] expected = {2, 2, 2};
            var index = 0;
            foreach (var sample in destination.Data)
            {
                Assert.AreEqual(expected[index], sample);
                index++;
            }
        }

        [Test]
        public void TestEvalBatching()
        {
            var multinomial = new Multinomial(2018);
            // Two rows of logits: each batch entry is sampled independently.
            var source = new Tensor
            {
                Data = new float[2, 3]
                {
                    {Mathf.Log(0.1f) - 50, Mathf.Log(0.2f) - 50, Mathf.Log(0.7f) - 50},
                    {Mathf.Log(0.3f) - 25, Mathf.Log(0.4f) - 25, Mathf.Log(0.3f) - 25},
                },
                ValueType = Tensor.TensorType.FloatingPoint
            };
            var destination = new Tensor
            {
                Data = new float[2, 3],
                ValueType = Tensor.TensorType.FloatingPoint
            };

            multinomial.Eval(source, destination);

            float[] expected = {2, 2, 2, 0, 1, 0};
            var index = 0;
            foreach (var sample in destination.Data)
            {
                Assert.AreEqual(expected[index], sample);
                index++;
            }
        }

        [Test]
        public void TestSrcInt()
        {
            // Integer sources are not supported yet.
            var multinomial = new Multinomial(2018);
            var source = new Tensor
            {
                ValueType = Tensor.TensorType.Integer
            };

            Assert.Throws<NotImplementedException>(() => multinomial.Eval(source, null));
        }

        [Test]
        public void TestDstInt()
        {
            // The destination must be a floating point tensor.
            var multinomial = new Multinomial(2018);
            var source = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint
            };
            var destination = new Tensor
            {
                ValueType = Tensor.TensorType.Integer
            };

            Assert.Throws<ArgumentException>(() => multinomial.Eval(source, destination));
        }

        [Test]
        public void TestSrcDataNull()
        {
            var multinomial = new Multinomial(2018);
            var source = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint
            };
            var destination = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint
            };

            Assert.Throws<ArgumentNullException>(() => multinomial.Eval(source, destination));
        }

        [Test]
        public void TestDstDataNull()
        {
            var multinomial = new Multinomial(2018);
            var source = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint,
                Data = new float[1]
            };
            var destination = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint
            };

            Assert.Throws<ArgumentNullException>(() => multinomial.Eval(source, destination));
        }

        [Test]
        public void TestSrcWrongShape()
        {
            // Sources must be rank 2 (batch x categories).
            var multinomial = new Multinomial(2018);
            var source = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint,
                Data = new float[1]
            };
            var destination = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint,
                Data = new float[1]
            };

            Assert.Throws<ArgumentException>(() => multinomial.Eval(source, destination));
        }

        [Test]
        public void TestDstWrongShape()
        {
            // Destinations must be rank 2 as well.
            var multinomial = new Multinomial(2018);
            var source = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint,
                Data = new float[1, 1]
            };
            var destination = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint,
                Data = new float[1]
            };

            Assert.Throws<ArgumentException>(() => multinomial.Eval(source, destination));
        }

        [Test]
        public void TestUnequalBatchSize()
        {
            // Source and destination batch dimensions must agree.
            var multinomial = new Multinomial(2018);
            var source = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint,
                Data = new float[1, 1]
            };
            var destination = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint,
                Data = new float[2, 1]
            };

            Assert.Throws<ArgumentException>(() => multinomial.Eval(source, destination));
        }
    }
}

122
UnitySDK/Assets/ML-Agents/Editor/Tests/RandomNormalTest.cs


using System;
using NUnit.Framework;
using MLAgents.InferenceBrain;
using MLAgents.InferenceBrain.Utils;
namespace MLAgents.Tests
{
    /// <summary>
    /// Tests for the RandomNormal generator: scalar draws with various
    /// mean/stddev settings, tensor filling, and argument validation.
    /// </summary>
    public class RandomNormalTest
    {
        [Test]
        public void RandomNormalTestTwoDouble()
        {
            var randomNormal = new RandomNormal(2018);

            // First two draws for seed 2018 with default mean/stddev.
            Assert.AreEqual(-0.46666, randomNormal.NextDouble(), 0.0001);
            Assert.AreEqual(-0.37989, randomNormal.NextDouble(), 0.0001);
        }

        [Test]
        public void RandomNormalTestWithMean()
        {
            var randomNormal = new RandomNormal(2018, 5.0f);

            // Same underlying draws, shifted by the mean of 5.
            Assert.AreEqual(4.53333, randomNormal.NextDouble(), 0.0001);
            Assert.AreEqual(4.6201, randomNormal.NextDouble(), 0.0001);
        }

        [Test]
        public void RandomNormalTestWithStddev()
        {
            var randomNormal = new RandomNormal(2018, 1.0f, 4.2f);

            Assert.AreEqual(-0.9599, randomNormal.NextDouble(), 0.0001);
            Assert.AreEqual(-0.5955, randomNormal.NextDouble(), 0.0001);
        }

        [Test]
        public void RandomNormalTestWithMeanStddev()
        {
            var randomNormal = new RandomNormal(2018, -3.2f, 2.2f);

            Assert.AreEqual(-4.2266, randomNormal.NextDouble(), 0.0001);
            Assert.AreEqual(-4.0357, randomNormal.NextDouble(), 0.0001);
        }

        [Test]
        public void RandomNormalTestTensorInt()
        {
            // Integer tensors cannot be filled with normal samples.
            var randomNormal = new RandomNormal(1982);
            var tensor = new Tensor
            {
                ValueType = Tensor.TensorType.Integer
            };

            Assert.Throws<NotImplementedException>(() => randomNormal.FillTensor(tensor));
        }

        [Test]
        public void RandomNormalTestDataNull()
        {
            // A tensor without backing data cannot be filled.
            var randomNormal = new RandomNormal(1982);
            var tensor = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint
            };

            Assert.Throws<ArgumentNullException>(() => randomNormal.FillTensor(tensor));
        }

        [Test]
        public void RandomNormalTestTensor()
        {
            var randomNormal = new RandomNormal(1982);
            var tensor = new Tensor
            {
                ValueType = Tensor.TensorType.FloatingPoint,
                Data = Array.CreateInstance(typeof(float), new long[3] {3, 4, 2})
            };

            randomNormal.FillTensor(tensor);

            // Expected flattened contents of the 3x4x2 tensor for seed 1982.
            float[] expected = new float[]
            {
                -0.2139822f,
                0.5051259f,
                -0.5640336f,
                -0.3357787f,
                -0.2055894f,
                -0.09432302f,
                -0.01419199f,
                0.53621f,
                -0.5507085f,
                -0.2651141f,
                0.09315512f,
                -0.04918706f,
                -0.179625f,
                0.2280539f,
                0.1883962f,
                0.4047216f,
                0.1704049f,
                0.5050544f,
                -0.3365685f,
                0.3542781f,
                0.5951571f,
                0.03460682f,
                -0.5537263f,
                -0.4378373f,
            };
            var index = 0;
            foreach (float sample in tensor.Data)
            {
                Assert.AreEqual(sample, expected[index], 0.0001);
                index++;
            }
        }
    }
}
正在加载...
取消
保存