浏览代码
Merge pull request #1451 from Unity-Technologies/release-v0.6-revertTF1
Merge pull request #1451 from Unity-Technologies/release-v0.6-revertTF1
Release v0.6 revert tf1/develop-generalizationTraining-TrainerController
GitHub
6 年前
当前提交
e9121bb5
共有 108 个文件被更改,包括 1283 次插入 和 261 次删除
-
2UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs
-
2UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs
-
2UnitySDK/Assets/ML-Agents/Examples/3DBall/Brains/3DBallHardLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/3DBall/Brains/3DBallLearning.asset
-
10UnitySDK/Assets/ML-Agents/Examples/BananaCollectors/Brains/BananaLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/Basic/Brains/BasicLearning.asset
-
10UnitySDK/Assets/ML-Agents/Examples/Bouncer/Brains/BouncerLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/Crawler/Brains/CrawlerDynamicLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/Crawler/Brains/CrawlerStaticLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/GridWorld/Brains/GridWorldLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/Hallway/Brains/HallwayLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/PushBlock/Brains/PushBlockLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/Pyramids/Brains/PyramidsLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/Reacher/Brains/ReacherLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/Soccer/Brains/GoalieLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/Soccer/Brains/StrikerLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/Tennis/Brains/TennisLearning.asset
-
10UnitySDK/Assets/ML-Agents/Examples/Walker/Brains/WalkerLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/WallJump/Brains/BigWallJumpLearning.asset
-
2UnitySDK/Assets/ML-Agents/Examples/WallJump/Brains/SmallWallJumpLearning.asset
-
6UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs
-
5UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs
-
12UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelParamLoader.cs
-
5UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs
-
6UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs
-
36UnitySDK/Assets/ML-Agents/Scripts/LearningBrain.cs
-
19docs/Background-TensorFlow.md
-
38docs/Basic-Guide.md
-
14docs/FAQ.md
-
14docs/Getting-Started-with-Balance-Ball.md
-
10docs/Learning-Environment-Design-Learning-Brains.md
-
4docs/Learning-Environment-Executable.md
-
4docs/ML-Agents-Overview.md
-
12docs/Migrating.md
-
2docs/Readme.md
-
2docs/Training-Imitation-Learning.md
-
2docs/Training-ML-Agents.md
-
2ml-agents/mlagents/trainers/policy.py
-
224UnitySDK/Assets/ML-Agents/Editor/Tests/MultinomialTest.cs
-
122UnitySDK/Assets/ML-Agents/Editor/Tests/RandomNormalTest.cs
-
7UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBallHardLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBallLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/BananaCollectors/TFModels/BananaLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Basic/TFModels/BasicLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Bouncer/TFModels/BouncerLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamicLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStaticLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorldLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Hallway/TFModels/HallwayLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/PushBlock/TFModels/PushBlockLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Pyramids/TFModels/PyramidsLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Reacher/TFModels/ReacherLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Soccer/TFModels/GoalieLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Soccer/TFModels/StrikerLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Tennis/TFModels/TennisLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Walker/TFModels/WalkerLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/BigWallJumpLearning.bytes.meta
-
7UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/SmallWallJumpLearning.bytes.meta
-
213UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TFSharpInferenceEngine.cs
-
12UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TFSharpInferenceEngine.cs.meta
-
40UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Tensor.cs
-
11UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Tensor.cs.meta
-
8UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils.meta
-
55docs/Using-TensorFlow-Sharp-in-Unity.md
-
117docs/images/imported-tensorflowsharp.png
-
98UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/Multinomial.cs
-
11UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/Multinomial.cs.meta
-
105UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/RandomNormal.cs
-
11UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/RandomNormal.cs.meta
-
7UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBallLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBallHardLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/BananaCollectors/TFModels/BananaLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Basic/TFModels/BasicLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Bouncer/TFModels/BouncerLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamicLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStaticLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorldLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Hallway/TFModels/HallwayLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/PushBlock/TFModels/PushBlockLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Pyramids/TFModels/PyramidsLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Reacher/TFModels/ReacherLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Soccer/TFModels/GoalieLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Soccer/TFModels/StrikerLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Tennis/TFModels/TennisLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/Walker/TFModels/WalkerLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/BigWallJumpLearning.tf.meta
-
7UnitySDK/Assets/ML-Agents/Examples/WallJump/TFModels/SmallWallJumpLearning.tf.meta
-
8UnitySDK/Assets/ML-Agents/Plugins/InferenceEngine.meta
-
0/UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBallLearning.bytes
-
0/UnitySDK/Assets/ML-Agents/Examples/3DBall/TFModels/3DBallHardLearning.bytes
-
0/UnitySDK/Assets/ML-Agents/Examples/BananaCollectors/TFModels/BananaLearning.bytes
-
0/UnitySDK/Assets/ML-Agents/Examples/Basic/TFModels/BasicLearning.bytes
-
0/UnitySDK/Assets/ML-Agents/Examples/Bouncer/TFModels/BouncerLearning.bytes
-
0/UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStaticLearning.bytes
-
0/UnitySDK/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamicLearning.bytes
-
0/UnitySDK/Assets/ML-Agents/Examples/GridWorld/TFModels/GridWorldLearning.bytes
-
0/UnitySDK/Assets/ML-Agents/Examples/Hallway/TFModels/HallwayLearning.bytes
-
0/UnitySDK/Assets/ML-Agents/Examples/PushBlock/TFModels/PushBlockLearning.bytes
|
|||
using System; |
|||
using NUnit.Framework; |
|||
using UnityEngine; |
|||
using MLAgents.InferenceBrain; |
|||
using MLAgents.InferenceBrain.Utils; |
|||
|
|||
namespace MLAgents.Tests |
|||
{ |
|||
public class MultinomialTest |
|||
{ |
|||
[Test] |
|||
public void TestEvalP() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
Data = new float[1, 3] {{0.1f, 0.2f, 0.7f}}, |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Tensor dst = new Tensor |
|||
{ |
|||
Data = new float[1, 3], |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
m.Eval(src, dst); |
|||
|
|||
float[] reference = {2, 2, 1}; |
|||
int i = 0; |
|||
foreach (var f in dst.Data) |
|||
{ |
|||
Assert.AreEqual(reference[i], f); |
|||
++i; |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestEvalLogits() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
Data = new float[1, 3] {{Mathf.Log(0.1f) - 50, Mathf.Log(0.2f) - 50, Mathf.Log(0.7f) - 50}}, |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Tensor dst = new Tensor |
|||
{ |
|||
Data = new float[1, 3], |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
m.Eval(src, dst); |
|||
|
|||
float[] reference = {2, 2, 2}; |
|||
int i = 0; |
|||
foreach (var f in dst.Data) |
|||
{ |
|||
Assert.AreEqual(reference[i], f); |
|||
++i; |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestEvalBatching() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
Data = new float[2, 3] |
|||
{ |
|||
{Mathf.Log(0.1f) - 50, Mathf.Log(0.2f) - 50, Mathf.Log(0.7f) - 50}, |
|||
{Mathf.Log(0.3f) - 25, Mathf.Log(0.4f) - 25, Mathf.Log(0.3f) - 25}, |
|||
|
|||
}, |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Tensor dst = new Tensor |
|||
{ |
|||
Data = new float[2, 3], |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
m.Eval(src, dst); |
|||
|
|||
float[] reference = {2, 2, 2, 0, 1, 0}; |
|||
int i = 0; |
|||
foreach (var f in dst.Data) |
|||
{ |
|||
Assert.AreEqual(reference[i], f); |
|||
++i; |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSrcInt() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.Integer |
|||
}; |
|||
|
|||
Assert.Throws<NotImplementedException>(() => m.Eval(src, null)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestDstInt() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.Integer |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSrcDataNull() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentNullException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestDstDataNull() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1] |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentNullException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSrcWrongShape() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1] |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1] |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestDstWrongShape() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1, 1] |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1] |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestUnequalBatchSize() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1, 1] |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[2, 1] |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
|
|||
} |
|||
} |
|
|||
using System; |
|||
using NUnit.Framework; |
|||
using MLAgents.InferenceBrain; |
|||
using MLAgents.InferenceBrain.Utils; |
|||
|
|||
namespace MLAgents.Tests |
|||
{ |
|||
|
|||
public class RandomNormalTest |
|||
{ |
|||
|
|||
[Test] |
|||
public void RandomNormalTestTwoDouble() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018); |
|||
|
|||
Assert.AreEqual(-0.46666, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(-0.37989, rn.NextDouble(), 0.0001); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestWithMean() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018, 5.0f); |
|||
|
|||
Assert.AreEqual(4.53333, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(4.6201, rn.NextDouble(), 0.0001); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestWithStddev() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018, 1.0f, 4.2f); |
|||
|
|||
Assert.AreEqual(-0.9599, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(-0.5955, rn.NextDouble(), 0.0001); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestWithMeanStddev() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018, -3.2f, 2.2f); |
|||
|
|||
Assert.AreEqual(-4.2266, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(-4.0357, rn.NextDouble(), 0.0001); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestTensorInt() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(1982); |
|||
Tensor t = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.Integer |
|||
}; |
|||
|
|||
Assert.Throws<NotImplementedException>(() => rn.FillTensor(t)); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestDataNull() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(1982); |
|||
Tensor t = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentNullException>(() => rn.FillTensor(t)); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestTensor() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(1982); |
|||
Tensor t = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = Array.CreateInstance(typeof(float), new long[3] {3, 4, 2}) |
|||
}; |
|||
|
|||
rn.FillTensor(t); |
|||
|
|||
float[] reference = new float[] |
|||
{ |
|||
-0.2139822f, |
|||
0.5051259f, |
|||
-0.5640336f, |
|||
-0.3357787f, |
|||
-0.2055894f, |
|||
-0.09432302f, |
|||
-0.01419199f, |
|||
0.53621f, |
|||
-0.5507085f, |
|||
-0.2651141f, |
|||
0.09315512f, |
|||
-0.04918706f, |
|||
-0.179625f, |
|||
0.2280539f, |
|||
0.1883962f, |
|||
0.4047216f, |
|||
0.1704049f, |
|||
0.5050544f, |
|||
-0.3365685f, |
|||
0.3542781f, |
|||
0.5951571f, |
|||
0.03460682f, |
|||
-0.5537263f, |
|||
-0.4378373f, |
|||
}; |
|||
|
|||
int i = 0; |
|||
foreach (float f in t.Data) |
|||
{ |
|||
Assert.AreEqual(f, reference[i], 0.0001); |
|||
++i; |
|||
} |
|||
|
|||
|
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 8a2da2218425f46e9921caefda4b7813 |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 9f58800fa9d54477aa01ee258842f6b3 |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 69bd818d72b944849916d2fda9fe471b |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 503ce1e8257904bd0b5be8f7fb4b5d28 |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 760d2b8347b4b46e3a44d9b989e1304e |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 9482a8782450a4d87b20942c4523176b |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: e256bd37f98f246e5be72618766d0a93 |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 0fd168a0ea1d04ef9a68c80cf452ce3d |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 84588668e6ea948d3ab55bb813cc769b |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: e22850d2072904a0ab06069cda2599e5 |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 7d1c7f27447234c3a81169de00dcaa8a |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 5fb4a3624e9ca4e1c81b51b5117cb31e |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 890ab8f03425c4a80a52ba674ddec3f3 |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 23410257d39d44616bfefdff59c7fbc9 |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 6d4281b70d41f48cb83d663b84f78c9a |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 48ab33cf9fbee4883948187618027835 |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: c118879bb5db84f269e4da23ba8c4f61 |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 92cd96b2c34334db692e93af25b64d2a |
|||
TextScriptImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
#if ENABLE_TENSORFLOW
|
|||
using System.Collections.Generic; |
|||
using TensorFlow; |
|||
using System.Linq; |
|||
using System; |
|||
using UnityEngine.Profiling; |
|||
using System.Runtime.InteropServices; |
|||
using UnityEngine; |
|||
|
|||
namespace MLAgents.InferenceBrain |
|||
{ |
|||
/// <summary>
|
|||
/// TFSharpInferenceEngine - Inference engine utilizing the TensorFlow Sharp package to run inference
|
|||
/// on frozen TensorFlow models
|
|||
/// </summary>
|
|||
public class TFSharpInferenceEngine |
|||
{ |
|||
private TFGraph m_graph; |
|||
private TFSession m_session; |
|||
|
|||
public void PrepareModel(byte[] model) |
|||
{ |
|||
Profiler.BeginSample("TFSharpInferenceComponent.PrepareModel"); |
|||
|
|||
#if UNITY_ANDROID && !UNITY_EDITOR
|
|||
// This needs to ba called only once and will raise an exception if called multiple times
|
|||
try{ |
|||
TensorFlowSharp.Android.NativeBinding.Init(); |
|||
} |
|||
catch{ |
|||
|
|||
} |
|||
#endif
|
|||
m_graph = new TFGraph(); |
|||
m_graph.Import(model); |
|||
m_session = new TFSession(m_graph); |
|||
Profiler.EndSample(); |
|||
} |
|||
|
|||
public int ExecuteGraph(IEnumerable<Tensor> inputs_it, IEnumerable<Tensor> outputs_it) |
|||
{ |
|||
Profiler.BeginSample("TFSharpInferenceComponent.ExecuteGraph"); |
|||
Tensor[] inputs = inputs_it.ToArray(); |
|||
Tensor[] outputs = outputs_it.ToArray(); |
|||
|
|||
// TODO: Can/should we pre-allocate that?
|
|||
TFSession.Runner runner = m_session.GetRunner(); |
|||
|
|||
inputs.ToList().ForEach((Tensor input) => |
|||
{ |
|||
if (input.Shape.Length == 0) |
|||
{ |
|||
var data = input.Data.GetValue(0); |
|||
if (input.DataType == typeof(int)) |
|||
{ |
|||
runner.AddInput(m_graph[input.Name][0], (int)data); |
|||
} |
|||
else |
|||
{ |
|||
runner.AddInput(m_graph[input.Name][0], (float)data); |
|||
} |
|||
} |
|||
else |
|||
{ |
|||
runner.AddInput(m_graph[input.Name][0], input.Data); |
|||
} |
|||
}); |
|||
|
|||
// TODO: better way to pre-allocate this?
|
|||
outputs.ToList().ForEach(s => runner.Fetch(s.Name)); |
|||
|
|||
TFStatus status = new TFStatus(); |
|||
Profiler.BeginSample("TFSharpInferenceComponent.ExecuteGraph.RunnerRun"); |
|||
var out_tensors = runner.Run(status); |
|||
Profiler.EndSample(); |
|||
|
|||
if (!status.Ok) |
|||
{ |
|||
Debug.LogError(status.StatusMessage); |
|||
return -1; |
|||
} |
|||
|
|||
Debug.Assert(outputs.Length == out_tensors.Length); |
|||
|
|||
for (var i = 0; i < outputs.Length; ++i) |
|||
{ |
|||
if (outputs[i].Shape.Length == 0) |
|||
{ |
|||
// Handle scalars
|
|||
outputs[i].Data = Array.CreateInstance(outputs[i].DataType, new long[1] {1}); |
|||
outputs[i].Data.SetValue(out_tensors[i].GetValue(), 0); |
|||
} |
|||
else |
|||
{ |
|||
outputs[i].Data = out_tensors[i].GetValue() as Array; |
|||
} |
|||
} |
|||
|
|||
Profiler.EndSample(); |
|||
// TODO: create error codes
|
|||
return 0; |
|||
} |
|||
|
|||
[DllImport("libtensorflow")] |
|||
private static extern unsafe void TF_OperationGetAttrType(IntPtr oper, string attr_name, |
|||
TFDataType* value, IntPtr status); |
|||
|
|||
[DllImport("libtensorflow")] |
|||
private static extern unsafe void TF_OperationGetAttrShape(IntPtr oper, string attr_name, long[] value, |
|||
int num_dims, IntPtr status); |
|||
|
|||
private Tensor GetOpMetadata(TFOperation op) |
|||
{ |
|||
TFStatus status = new TFStatus(); |
|||
|
|||
// Query the shape
|
|||
long[] shape = null; |
|||
var shape_attr = op.GetAttributeMetadata("shape", status); |
|||
if (!status.Ok || shape_attr.TotalSize <= 0) |
|||
{ |
|||
Debug.LogWarning("Operation " + op.Name + " does not contain shape attribute or it" + |
|||
" doesn't contain valid shape data!"); |
|||
} |
|||
else |
|||
{ |
|||
if (shape_attr.IsList) |
|||
{ |
|||
throw new NotImplementedException("Querying lists is not implemented yet!"); |
|||
} |
|||
else |
|||
{ |
|||
TFStatus s = new TFStatus(); |
|||
long[] dims = new long[shape_attr.TotalSize]; |
|||
TF_OperationGetAttrShape(op.Handle, "shape", dims, (int)shape_attr.TotalSize, |
|||
s.Handle); |
|||
if (!status.Ok) |
|||
{ |
|||
throw new FormatException("Could not query model for op shape (" + op.Name + ")"); |
|||
} |
|||
else |
|||
{ |
|||
shape = new long[dims.Length]; |
|||
for (int i = 0; i < shape_attr.TotalSize; ++i) |
|||
{ |
|||
if (dims[i] == -1) |
|||
{ |
|||
// we have to use batchsize 1
|
|||
shape[i] = 1; |
|||
} |
|||
else |
|||
{ |
|||
shape[i] = dims[i]; |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
// Query the data type
|
|||
TFDataType type_value = new TFDataType(); |
|||
unsafe |
|||
{ |
|||
TFStatus s = new TFStatus(); |
|||
TF_OperationGetAttrType(op.Handle, "dtype", &type_value, s.Handle); |
|||
if (!s.Ok) |
|||
{ |
|||
Debug.LogWarning("Operation " + op.Name + |
|||
": error retrieving dtype, assuming float!"); |
|||
type_value = TFDataType.Float; |
|||
} |
|||
} |
|||
|
|||
Tensor.TensorType placeholder_type = Tensor.TensorType.FloatingPoint; |
|||
switch (type_value) |
|||
{ |
|||
case TFDataType.Float: |
|||
placeholder_type = Tensor.TensorType.FloatingPoint; |
|||
break; |
|||
case TFDataType.Int32: |
|||
placeholder_type = Tensor.TensorType.Integer; |
|||
break; |
|||
default: |
|||
Debug.LogWarning("Operation " + op.Name + |
|||
" is not a float/integer. Proceed at your own risk!"); |
|||
break; |
|||
} |
|||
|
|||
Tensor t = new Tensor |
|||
{ |
|||
Data = null, |
|||
Name = op.Name, |
|||
Shape = shape, |
|||
ValueType = placeholder_type |
|||
}; |
|||
return t; |
|||
} |
|||
|
|||
public IEnumerable<Tensor> InputFeatures() |
|||
{ |
|||
List<Tensor> inputs = new List<Tensor>(); |
|||
foreach (var op in m_graph.GetEnumerator()) |
|||
{ |
|||
if (op.OpType == "Placeholder") |
|||
{ |
|||
inputs.Add(GetOpMetadata(op)); |
|||
} |
|||
} |
|||
|
|||
return inputs; |
|||
} |
|||
} |
|||
} |
|||
#endif
|
|
|||
fileFormatVersion: 2 |
|||
guid: 120cbe3fa702f4e428f57ae1d893a0a7 |
|||
timeCreated: 1535148728 |
|||
licenseType: Free |
|||
MonoImporter: |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
|
|||
namespace MLAgents.InferenceBrain |
|||
{ |
|||
|
|||
/// <summary>
|
|||
/// Tensor - A class to encapsulate a Tensor used for inference.
|
|||
///
|
|||
/// This class contains the Array that holds the data array, the shapes, type and the placeholder in the
|
|||
/// execution graph. All the fields are editable in the inspector, allowing the user to specify everything
|
|||
/// but the data in a graphical way.
|
|||
/// </summary>
|
|||
[System.Serializable] |
|||
public class Tensor |
|||
{ |
|||
public enum TensorType |
|||
{ |
|||
Integer, |
|||
FloatingPoint |
|||
}; |
|||
|
|||
private static Dictionary<TensorType, Type> m_typeMap = new Dictionary<TensorType, Type>() |
|||
{ |
|||
{ TensorType.FloatingPoint, typeof(float)}, |
|||
{TensorType.Integer, typeof(int)} |
|||
}; |
|||
|
|||
public string Name; |
|||
public TensorType ValueType; |
|||
// Since Type is not serializable, we use the DisplayType for the Inspector
|
|||
public Type DataType |
|||
{ |
|||
get { return m_typeMap[ValueType]; } |
|||
} |
|||
public long[] Shape; |
|||
public Array Data; |
|||
} |
|||
|
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 926149e757bc849689e00e12d8c6fbdb |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 7872e22895343467b9fe96d336a7edba |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
# Using TensorFlowSharp in Unity (Experimental) |
|||
|
|||
The ML-Agents toolkit allows you to use pre-trained |
|||
[TensorFlow graphs](https://www.tensorflow.org/programmers_guide/graphs) |
|||
inside your Unity |
|||
games. This support is possible thanks to the |
|||
[TensorFlowSharp project](https://github.com/migueldeicaza/TensorFlowSharp). |
|||
The primary purpose for this support is to use the TensorFlow models produced by |
|||
the ML-Agents toolkit's own training programs, but a side benefit is that you |
|||
can use any TensorFlow model. |
|||
|
|||
_Notice: This feature is still experimental. While it is possible to embed |
|||
trained models into Unity games, Unity Technologies does not officially support |
|||
this use-case for production games at this time. As such, no guarantees are |
|||
provided regarding the quality of experience. If you encounter issues regarding |
|||
battery life, or general performance (especially on mobile), please let us |
|||
know._ |
|||
|
|||
## Supported devices |
|||
|
|||
* Linux 64 bits |
|||
* Mac OS X 64 bits |
|||
* Windows 64 bits |
|||
* iOS (Requires additional steps) |
|||
* Android |
|||
|
|||
## Requirements |
|||
|
|||
* Unity 2017.4 or above |
|||
* Unity TensorFlow Plugin ([Download here](https://s3.amazonaws.com/unity-ml-agents/0.5/TFSharpPlugin.unitypackage)) |
|||
|
|||
## Using TensorFlowSharp with ML-Agents |
|||
|
|||
Go to `Edit` -> `Player Settings` and add `ENABLE_TENSORFLOW` to the `Scripting |
|||
Define Symbols` for each type of device you want to use (**`PC, Mac and Linux |
|||