Initial commit. Modified the code to use TF# exclusively
/develop-generalizationTraining-TrainerController
vincentpierre
6 years ago
Commit 52d631b7
17 files changed, with 561 additions and 34 deletions
- UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs (2)
- UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs (2)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs (6)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs (5)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelParamLoader.cs (10)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs (5)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs (6)
- UnitySDK/Assets/ML-Agents/Scripts/LearningBrain.cs (35)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TFSharpInferenceEngine.cs (236)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TFSharpInferenceEngine.cs.meta (12)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Tensor.cs (40)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Tensor.cs.meta (11)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/Multinomial.cs (98)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/Multinomial.cs.meta (11)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/RandomNormal.cs (105)
- UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/RandomNormal.cs.meta (11)

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TFSharpInferenceEngine.cs

#if ENABLE_TENSORFLOW
using System.Collections.Generic;
using TensorFlow;
using System.Linq;
using System;
using UnityEngine.Profiling;
using System.Runtime.InteropServices;
using UnityEngine;

namespace MLAgents.InferenceBrain
{
    /// <summary>
    /// TFSharpInferenceEngine - Inference engine utilizing the TensorFlow Sharp package to run inference
    /// on frozen TensorFlow models
    /// </summary>
    public class TFSharpInferenceEngine
    {
        private TFGraph m_graph;
        private TFSession m_session;

        public void PrepareModel(byte[] model)
        {
            Profiler.BeginSample("TFSharpInferenceComponent.PrepareModel");

#if UNITY_ANDROID && !UNITY_EDITOR
            // This needs to be called only once and will raise an exception if called multiple times
            try
            {
                TensorFlowSharp.Android.NativeBinding.Init();
            }
            catch
            {
            }
#endif
            m_graph = new TFGraph();
            m_graph.Import(model);
            m_session = new TFSession(m_graph);
            Profiler.EndSample();
        }

        public int ExecuteGraph(IEnumerable<Tensor> inputs_it, IEnumerable<Tensor> outputs_it)
        {
            Profiler.BeginSample("TFSharpInferenceComponent.ExecuteGraph");
            Tensor[] inputs = inputs_it.ToArray();
            Tensor[] outputs = outputs_it.ToArray();

            // TODO: Can/should we pre-allocate that?
            TFSession.Runner runner = m_session.GetRunner();

            inputs.ToList().ForEach((Tensor input) =>
            {
                if (input.Shape.Length == 0)
                {
                    var data = input.Data.GetValue(0);
                    if (input.DataType == typeof(int))
                    {
                        runner.AddInput(m_graph[input.Name][0], (int)data);
                    }
                    else
                    {
                        runner.AddInput(m_graph[input.Name][0], (float)data);
                    }
                }
                else
                {
                    runner.AddInput(m_graph[input.Name][0], input.Data);
                }
            });

            // TODO: better way to pre-allocate this?
            outputs.ToList().ForEach(s => runner.Fetch(s.Name));

            TFStatus status = new TFStatus();
            Profiler.BeginSample("TFSharpInferenceComponent.ExecuteGraph.RunnerRun");
            var out_tensors = runner.Run(status);
            Profiler.EndSample();

            if (!status.Ok)
            {
                Debug.LogError(status.StatusMessage);
                return -1;
            }

            Debug.Assert(outputs.Length == out_tensors.Length);

            for (var i = 0; i < outputs.Length; ++i)
            {
                if (outputs[i].Shape.Length == 0)
                {
                    // Handle scalars
                    outputs[i].Data = Array.CreateInstance(outputs[i].DataType, new long[1] {1});
                    outputs[i].Data.SetValue(out_tensors[i].GetValue(), 0);
                }
                else
                {
                    outputs[i].Data = out_tensors[i].GetValue() as Array;
                }
            }

            Profiler.EndSample();
            // TODO: create error codes
            return 0;
        }

        [DllImport("libtensorflow")]
        private static extern unsafe void TF_OperationGetAttrType(IntPtr oper, string attr_name,
            TFDataType* value, IntPtr status);

        [DllImport("libtensorflow")]
        private static extern unsafe void TF_OperationGetAttrShape(IntPtr oper, string attr_name, long[] value,
            int num_dims, IntPtr status);

        [DllImport("libtensorflow")]
        private static extern unsafe void TF_OperationGetAttrShapeList(
            IntPtr oper, string attr_name, long*[] dims, int[] num_dims, long num_shapes, long[] storage,
            long storage_size, IntPtr status);

        private Tensor GetOpMetadata(TFOperation op)
        {
            TFStatus status = new TFStatus();

            // Query the shape
            long[] shape = null;
            var shape_attr = op.GetAttributeMetadata("shape", status);
            if (!status.Ok || shape_attr.TotalSize <= 0)
            {
                Debug.LogWarning("Operation " + op.Name + " does not contain shape attribute or it" +
                                 " doesn't contain valid shape data!");
            }
            else
            {
                if (shape_attr.IsList)
                {
                    throw new NotImplementedException("Querying lists is not implemented yet!");
                }
                else
                {
                    TFStatus s = new TFStatus();
                    long[] dims = new long[shape_attr.TotalSize];
                    TF_OperationGetAttrShape(op.Handle, "shape", dims, (int)shape_attr.TotalSize,
                        s.Handle);
                    // Check the status of the shape query itself
                    if (!s.Ok)
                    {
                        throw new FormatException("Could not query model for op shape (" + op.Name + ")");
                    }
                    else
                    {
                        shape = new long[dims.Length];
                        for (int i = 0; i < shape_attr.TotalSize; ++i)
                        {
                            if (dims[i] == -1)
                            {
                                // TODO: since CoreML doesn't support batching yet
                                // we have to use batch size 1
                                shape[i] = 1;
                            }
                            else
                            {
                                shape[i] = dims[i];
                            }
                        }
                    }
                }
            }

            // Query the data type
            TFDataType type_value = new TFDataType();
            unsafe
            {
                TFStatus s = new TFStatus();
                TF_OperationGetAttrType(op.Handle, "dtype", &type_value, s.Handle);
                if (!s.Ok)
                {
                    Debug.LogWarning("Operation " + op.Name +
                                     ": error retrieving dtype, assuming float!");
                    type_value = TFDataType.Float;
                }
            }

            Tensor.TensorType placeholder_type = Tensor.TensorType.FloatingPoint;
            switch (type_value)
            {
                case TFDataType.Float:
                    placeholder_type = Tensor.TensorType.FloatingPoint;
                    break;
                case TFDataType.Int32:
                    placeholder_type = Tensor.TensorType.Integer;
                    break;
                default:
                    Debug.LogWarning("Operation " + op.Name +
                                     " is not a float/integer. Proceed at your own risk!");
                    break;
            }

            Tensor t = new Tensor
            {
                Data = null,
                Name = op.Name,
                Shape = shape,
                ValueType = placeholder_type
            };
            return t;
        }

        public IEnumerable<Tensor> InputFeatures()
        {
            List<Tensor> inputs = new List<Tensor>();
            foreach (var op in m_graph.GetEnumerator())
            {
                if (op.OpType == "Placeholder")
                {
                    inputs.Add(GetOpMetadata(op));
                }
            }

            return inputs;
        }

        public IEnumerable<Tensor> OutputFeatures()
        {
            // In TF, any op can be an output
            List<Tensor> outputs = new List<Tensor>();
            foreach (var op in m_graph.GetEnumerator())
            {
                outputs.Add(GetOpMetadata(op));
            }

            return outputs;
        }

        public bool AllocateOutputs()
        {
            return false;
        }
    }
}
#endif
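
For orientation, a minimal usage sketch (not part of the commit): it assumes ENABLE_TENSORFLOW is defined, that a frozen graph's bytes are already loaded, that every placeholder is non-scalar, and that the graph exposes a float output op named "action" of shape 1 x 2 (the op name and shape are hypothetical).

#if ENABLE_TENSORFLOW
using System.Collections.Generic;
using System.Linq;
using UnityEngine;

namespace MLAgents.InferenceBrain
{
    // Hypothetical usage sketch for TFSharpInferenceEngine; names are examples only.
    public static class TFSharpInferenceExample
    {
        public static void Run(byte[] modelBytes)
        {
            var engine = new TFSharpInferenceEngine();
            engine.PrepareModel(modelBytes);

            // Allocate a data array for every placeholder the graph reports
            // (assumes non-scalar placeholders; zeros are fed as dummy values).
            var inputs = engine.InputFeatures().ToList();
            foreach (var input in inputs)
            {
                input.Data = System.Array.CreateInstance(input.DataType, input.Shape);
            }

            // Outputs are fetched by op name; ExecuteGraph fills Data in place.
            var action = new Tensor
            {
                Name = "action",               // hypothetical output op name
                ValueType = Tensor.TensorType.FloatingPoint,
                Shape = new long[] {1, 2},
                Data = new float[1, 2]
            };

            if (engine.ExecuteGraph(inputs, new List<Tensor> {action}) == 0)
            {
                Debug.Log("First action value: " + ((float[,])action.Data)[0, 0]);
            }
        }
    }
}
#endif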

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TFSharpInferenceEngine.cs.meta

fileFormatVersion: 2
guid: 120cbe3fa702f4e428f57ae1d893a0a7
timeCreated: 1535148728
licenseType: Free
MonoImporter:
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Tensor.cs

using System;
using System.Collections.Generic;

namespace MLAgents.InferenceBrain
{

    /// <summary>
    /// Tensor - A class to encapsulate a Tensor used for inference.
    ///
    /// This class contains the Array that holds the data array, the shapes, type and the placeholder in the
    /// execution graph. All the fields are editable in the inspector, allowing the user to specify everything
    /// but the data in a graphical way.
    /// </summary>
    [System.Serializable]
    public class Tensor
    {
        public enum TensorType
        {
            Integer,
            FloatingPoint
        };

        private static Dictionary<TensorType, Type> m_typeMap = new Dictionary<TensorType, Type>()
        {
            {TensorType.FloatingPoint, typeof(float)},
            {TensorType.Integer, typeof(int)}
        };

        public string Name;
        public TensorType ValueType;

        // Since Type is not serializable, we use the DisplayType for the Inspector
        public Type DataType
        {
            get { return m_typeMap[ValueType]; }
        }

        public long[] Shape;

        public Array Data;
    }

}
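
A small illustrative sketch (not part of the commit): it shows the intended pairing of Shape, ValueType and the backing System.Array; the tensor name is hypothetical.

using UnityEngine;

namespace MLAgents.InferenceBrain
{
    // Hypothetical sketch: a Tensor is metadata plus a multi-dimensional System.Array
    // that generators fill and appliers read.
    public static class TensorExample
    {
        public static void Run()
        {
            var observations = new Tensor
            {
                Name = "vector_observation",   // example placeholder name
                ValueType = Tensor.TensorType.FloatingPoint,
                Shape = new long[] {3, 5},
                Data = new float[3, 5]
            };

            // DataType is derived from ValueType through the static type map.
            Debug.Assert(observations.DataType == typeof(float));
        }
    }
}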

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Tensor.cs.meta

fileFormatVersion: 2
guid: 926149e757bc849689e00e12d8c6fbdb
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/Multinomial.cs

using System;
using Assert = UnityEngine.Assertions.Assert;
using UnityEngine;

namespace MLAgents.InferenceBrain.Utils
{
    /// <summary>
    /// Multinomial - Draws samples from a multinomial distribution in log space
    /// Reference: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/multinomial_op.cc
    /// </summary>
    public class Multinomial
    {
        private readonly System.Random m_random;

        public Multinomial(int seed)
        {
            m_random = new System.Random(seed);
        }

        /// <summary>
        /// Draw samples from a multinomial distribution based on log-probabilities specified in tensor src. The samples
        /// will be saved in the dst tensor.
        /// </summary>
        /// <param name="src">2-D tensor with shape batch_size x num_classes</param>
        /// <param name="dst">Allocated tensor with size batch_size x num_samples</param>
        /// <exception cref="NotImplementedException">Multinomial doesn't support integer tensors</exception>
        /// <exception cref="ArgumentException">Issue with tensor shape or type</exception>
        /// <exception cref="ArgumentNullException">At least one of the tensors is not allocated</exception>
        public void Eval(Tensor src, Tensor dst)
        {
            if (src.DataType != typeof(float))
            {
                throw new NotImplementedException("Multinomial does not support integer tensors yet!");
            }

            if (src.ValueType != dst.ValueType)
            {
                throw new ArgumentException("Source and destination tensors have different types!");
            }

            if (src.Data == null || dst.Data == null)
            {
                throw new ArgumentNullException();
            }

            float[,] input_data = src.Data as float[,];
            if (input_data == null)
            {
                throw new ArgumentException("Input data is not of the correct shape! Required batch x logits");
            }

            float[,] output_data = dst.Data as float[,];
            if (output_data == null)
            {
                throw new ArgumentException("Output data is not of the correct shape! Required batch x samples");
            }

            if (input_data.GetLength(0) != output_data.GetLength(0))
            {
                throw new ArgumentException("Batch size for input and output data is different!");
            }

            for (int batch = 0; batch < input_data.GetLength(0); ++batch)
            {
                // Find the maximum logit in this row (for numerical stability)
                float maxProb = float.NegativeInfinity;
                for (int cls = 0; cls < input_data.GetLength(1); ++cls)
                {
                    maxProb = Mathf.Max(input_data[batch, cls], maxProb);
                }

                // Exponentiate the shifted logits and accumulate the unnormalized CDF
                float sumProb = 0.0f;
                float[] cdf = new float[input_data.GetLength(1)];
                for (int cls = 0; cls < input_data.GetLength(1); ++cls)
                {
                    sumProb += Mathf.Exp(input_data[batch, cls] - maxProb);
                    cdf[cls] = sumProb;
                }

                // Generate the samples by inverting the CDF with uniform draws
                for (int sample = 0; sample < output_data.GetLength(1); ++sample)
                {
                    float p = (float)m_random.NextDouble() * sumProb;
                    int cls = 0;
                    while (cdf[cls] < p)
                    {
                        ++cls;
                    }

                    output_data[batch, sample] = cls;
                }
            }
        }
    }
}
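
A hedged usage sketch (not part of the commit): it draws 4 samples per row from 2 rows of made-up logits, using the Tensor class introduced above; tensor names and values are illustrative only.

using UnityEngine;

namespace MLAgents.InferenceBrain.Utils
{
    // Hypothetical sketch: sample discrete class indices from a batch of logits.
    public static class MultinomialExample
    {
        public static void Run()
        {
            var logits = new Tensor
            {
                Name = "log_probs",
                ValueType = Tensor.TensorType.FloatingPoint,
                Shape = new long[] {2, 3},
                Data = new float[2, 3] {{0.1f, 1.2f, -0.5f}, {2.0f, 0.0f, 0.0f}}
            };
            var samples = new Tensor
            {
                Name = "samples",
                ValueType = Tensor.TensorType.FloatingPoint,
                Shape = new long[] {2, 4},
                Data = new float[2, 4]
            };

            // Classes with larger logits are drawn more often; class indices land in samples.Data.
            new Multinomial(seed: 42).Eval(logits, samples);
            Debug.Log(((float[,])samples.Data)[0, 0]);
        }
    }
}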

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/Multinomial.cs.meta

fileFormatVersion: 2
guid: 5c9e297dad748408db9e5ce26b940fe3
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/RandomNormal.cs

using System;

namespace MLAgents.InferenceBrain.Utils
{
    /// <summary>
    /// RandomNormal - A random number generator that produces normally distributed random numbers using the Marsaglia
    /// polar method (https://en.wikipedia.org/wiki/Marsaglia_polar_method)
    /// TODO: worth overriding System.Random instead of aggregating?
    /// </summary>
    public class RandomNormal
    {
        private readonly double m_mean;
        private readonly double m_stddev;
        private readonly System.Random m_random;

        public RandomNormal(int seed, float mean = 0.0f, float stddev = 1.0f)
        {
            m_mean = mean;
            m_stddev = stddev;
            m_random = new System.Random(seed);
        }

        // Each iteration produces two numbers. Hold one here for next call
        private bool m_hasSpare = false;
        private double m_spare = 0.0f;

        /// <summary>
        /// Return the next random double number
        /// </summary>
        /// <returns>Next random double number</returns>
        public double NextDouble()
        {
            if (m_hasSpare)
            {
                m_hasSpare = false;
                return m_spare * m_stddev + m_mean;
            }

            double u, v, s;
            do
            {
                u = m_random.NextDouble() * 2.0 - 1.0;
                v = m_random.NextDouble() * 2.0 - 1.0;
                s = u * u + v * v;
            } while (s >= 1.0 || s == 0.0);

            // Marsaglia polar method: both u and v are scaled by sqrt(-2 ln(s) / s)
            s = Math.Sqrt(-2.0 * Math.Log(s) / s);
            m_spare = u * s;
            m_hasSpare = true;

            return v * s * m_stddev + m_mean;
        }

        // Advance the multi-dimensional index "odometer style", starting at dimension 1
        // (dimension 0 is advanced by the caller).
        private void IncreaseNextDim(Array arr, long[] indices)
        {
            for (int i = 1; i < arr.Rank; ++i)
            {
                ++indices[i];
                if (i == arr.Rank - 1 || indices[i] < arr.GetLength(i))
                {
                    break;
                }
                else
                {
                    indices[i] = 0;
                }
            }
        }

        /// <summary>
        /// Fill a pre-allocated Tensor with random numbers
        /// </summary>
        /// <param name="t">The pre-allocated Tensor to fill</param>
        /// <exception cref="NotImplementedException">Throws when trying to fill a Tensor of type other than float</exception>
        /// <exception cref="ArgumentNullException">Throws when the Tensor is not allocated</exception>
        public void FillTensor(Tensor t)
        {
            if (t.DataType != typeof(float))
            {
                throw new NotImplementedException("Random Normal does not support integer tensors yet!");
            }

            if (t.Data == null)
            {
                throw new ArgumentNullException();
            }

            long[] indices = new long[t.Data.Rank];

            // Since IEnumerable is const, and we don't know the dimensions of the Array,
            // we need to traverse all the dimensions ourselves.
            // TODO: this seems like a nice general operation for the Tensor, consider moving it there
            do
            {
                t.Data.SetValue((float) NextDouble(), indices);
                ++indices[0];
                if (indices[0] == t.Data.GetLength(0))
                {
                    indices[0] = 0;
                    IncreaseNextDim(t.Data, indices);
                }
            } while (indices[t.Data.Rank - 1] < t.Data.GetLength(t.Data.Rank - 1));
        }
    }
}
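
A hedged usage sketch (not part of the commit): it fills a 2 x 3 float tensor with draws from a normal distribution with mean 0 and standard deviation 0.5, using a fixed seed; the tensor name is hypothetical.

using UnityEngine;

namespace MLAgents.InferenceBrain.Utils
{
    // Hypothetical sketch: fill a pre-allocated float tensor with Gaussian noise.
    public static class RandomNormalExample
    {
        public static void Run()
        {
            var noise = new Tensor
            {
                Name = "epsilon",
                ValueType = Tensor.TensorType.FloatingPoint,
                Shape = new long[] {2, 3},
                Data = new float[2, 3]
            };

            var rn = new RandomNormal(seed: 1, mean: 0.0f, stddev: 0.5f);
            rn.FillTensor(noise);   // overwrites every element of noise.Data
            Debug.Log(((float[,])noise.Data)[0, 0]);
        }
    }
}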

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/RandomNormal.cs.meta

fileFormatVersion: 2
guid: df8528cf20f0e4c64a4a7596eccc1631
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant: