vincentpierre
6 年前
当前提交
5a44d3ee
共有 4 个文件被更改,包括 520 次插入 和 180 次删除
-
2UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelParamLoader.cs
-
352UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TFSharpInferenceEngine.cs
-
224UnitySDK/Assets/ML-Agents/Editor/Tests/MultinomialTest.cs
-
122UnitySDK/Assets/ML-Agents/Editor/Tests/RandomNormalTest.cs
|
|||
using System; |
|||
using NUnit.Framework; |
|||
using UnityEngine; |
|||
using MLAgents.InferenceBrain; |
|||
using MLAgents.InferenceBrain.Utils; |
|||
|
|||
namespace MLAgents.Tests |
|||
{ |
|||
public class MultinomialTest |
|||
{ |
|||
[Test] |
|||
public void TestEvalP() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
Data = new float[1, 3] {{0.1f, 0.2f, 0.7f}}, |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Tensor dst = new Tensor |
|||
{ |
|||
Data = new float[1, 3], |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
m.Eval(src, dst); |
|||
|
|||
float[] reference = {2, 2, 1}; |
|||
int i = 0; |
|||
foreach (var f in dst.Data) |
|||
{ |
|||
Assert.AreEqual(reference[i], f); |
|||
++i; |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestEvalLogits() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
Data = new float[1, 3] {{Mathf.Log(0.1f) - 50, Mathf.Log(0.2f) - 50, Mathf.Log(0.7f) - 50}}, |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Tensor dst = new Tensor |
|||
{ |
|||
Data = new float[1, 3], |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
m.Eval(src, dst); |
|||
|
|||
float[] reference = {2, 2, 2}; |
|||
int i = 0; |
|||
foreach (var f in dst.Data) |
|||
{ |
|||
Assert.AreEqual(reference[i], f); |
|||
++i; |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestEvalBatching() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
Data = new float[2, 3] |
|||
{ |
|||
{Mathf.Log(0.1f) - 50, Mathf.Log(0.2f) - 50, Mathf.Log(0.7f) - 50}, |
|||
{Mathf.Log(0.3f) - 25, Mathf.Log(0.4f) - 25, Mathf.Log(0.3f) - 25}, |
|||
|
|||
}, |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Tensor dst = new Tensor |
|||
{ |
|||
Data = new float[2, 3], |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
m.Eval(src, dst); |
|||
|
|||
float[] reference = {2, 2, 2, 0, 1, 0}; |
|||
int i = 0; |
|||
foreach (var f in dst.Data) |
|||
{ |
|||
Assert.AreEqual(reference[i], f); |
|||
++i; |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSrcInt() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.Integer |
|||
}; |
|||
|
|||
Assert.Throws<NotImplementedException>(() => m.Eval(src, null)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestDstInt() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.Integer |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSrcDataNull() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentNullException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestDstDataNull() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1] |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentNullException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestSrcWrongShape() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1] |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1] |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestDstWrongShape() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1, 1] |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1] |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestUnequalBatchSize() |
|||
{ |
|||
Multinomial m = new Multinomial(2018); |
|||
|
|||
Tensor src = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[1, 1] |
|||
}; |
|||
Tensor dst = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = new float[2, 1] |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentException>(() => m.Eval(src, dst)); |
|||
} |
|||
|
|||
|
|||
} |
|||
} |
|
|||
using System; |
|||
using NUnit.Framework; |
|||
using MLAgents.InferenceBrain; |
|||
using MLAgents.InferenceBrain.Utils; |
|||
|
|||
namespace MLAgents.Tests |
|||
{ |
|||
|
|||
public class RandomNormalTest |
|||
{ |
|||
|
|||
[Test] |
|||
public void RandomNormalTestTwoDouble() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018); |
|||
|
|||
Assert.AreEqual(-0.46666, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(-0.37989, rn.NextDouble(), 0.0001); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestWithMean() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018, 5.0f); |
|||
|
|||
Assert.AreEqual(4.53333, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(4.6201, rn.NextDouble(), 0.0001); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestWithStddev() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018, 1.0f, 4.2f); |
|||
|
|||
Assert.AreEqual(-0.9599, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(-0.5955, rn.NextDouble(), 0.0001); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestWithMeanStddev() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018, -3.2f, 2.2f); |
|||
|
|||
Assert.AreEqual(-4.2266, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(-4.0357, rn.NextDouble(), 0.0001); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestTensorInt() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(1982); |
|||
Tensor t = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.Integer |
|||
}; |
|||
|
|||
Assert.Throws<NotImplementedException>(() => rn.FillTensor(t)); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestDataNull() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(1982); |
|||
Tensor t = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentNullException>(() => rn.FillTensor(t)); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestTensor() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(1982); |
|||
Tensor t = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = Array.CreateInstance(typeof(float), new long[3] {3, 4, 2}) |
|||
}; |
|||
|
|||
rn.FillTensor(t); |
|||
|
|||
float[] reference = new float[] |
|||
{ |
|||
-0.2139822f, |
|||
0.5051259f, |
|||
-0.5640336f, |
|||
-0.3357787f, |
|||
-0.2055894f, |
|||
-0.09432302f, |
|||
-0.01419199f, |
|||
0.53621f, |
|||
-0.5507085f, |
|||
-0.2651141f, |
|||
0.09315512f, |
|||
-0.04918706f, |
|||
-0.179625f, |
|||
0.2280539f, |
|||
0.1883962f, |
|||
0.4047216f, |
|||
0.1704049f, |
|||
0.5050544f, |
|||
-0.3365685f, |
|||
0.3542781f, |
|||
0.5951571f, |
|||
0.03460682f, |
|||
-0.5537263f, |
|||
-0.4378373f, |
|||
}; |
|||
|
|||
int i = 0; |
|||
foreach (float f in t.Data) |
|||
{ |
|||
Assert.AreEqual(f, reference[i], 0.0001); |
|||
++i; |
|||
} |
|||
|
|||
|
|||
} |
|||
} |
|||
} |
撰写
预览
正在加载...
取消
保存
Reference in new issue