浏览代码
fix bug in RandomNormal (#2294)
fix bug in RandomNormal (#2294)
* fix bug in RandomNormal, add test for distribution
* extract epsilon, rename vars
/develop-generalizationTraining-TrainerController
GitHub
6 年前
当前提交
b11efed9
共有 2 个文件被更改,包括 145 次插入 和 102 次删除
-
235 UnitySDK/Assets/ML-Agents/Editor/Tests/RandomNormalTest.cs
-
12 UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/Utils/RandomNormal.cs
|
|||
using System; |
|||
using System; |
|||
using UnityEngine; |
|||
using System.Collections; |
|||
|
|||
|
|||
|
|||
public class RandomNormalTest |
|||
{ |
|||
public class RandomNormalTest |
|||
{ |
|||
private const float firstValue = -1.19580f; |
|||
private const float secondValue = -0.97345f; |
|||
private const double epsilon = 0.0001; |
|||
[Test] |
|||
public void RandomNormalTestTwoDouble() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018); |
|||
[Test] |
|||
public void RandomNormalTestTwoDouble () |
|||
{ |
|||
RandomNormal rn = new RandomNormal (2018); |
|||
Assert.AreEqual(-0.46666, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(-0.37989, rn.NextDouble(), 0.0001); |
|||
} |
|||
Assert.AreEqual (firstValue, rn.NextDouble (), epsilon); |
|||
Assert.AreEqual (secondValue, rn.NextDouble (), epsilon); |
|||
} |
|||
[Test] |
|||
public void RandomNormalTestWithMean() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018, 5.0f); |
|||
[Test] |
|||
public void RandomNormalTestWithMean () |
|||
{ |
|||
RandomNormal rn = new RandomNormal (2018, 5.0f); |
|||
Assert.AreEqual(4.53333, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(4.6201, rn.NextDouble(), 0.0001); |
|||
} |
|||
Assert.AreEqual (firstValue + 5.0, rn.NextDouble (), epsilon); |
|||
Assert.AreEqual (secondValue + 5.0, rn.NextDouble (), epsilon); |
|||
} |
|||
[Test] |
|||
public void RandomNormalTestWithStddev() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018, 1.0f, 4.2f); |
|||
[Test] |
|||
public void RandomNormalTestWithStddev () |
|||
{ |
|||
RandomNormal rn = new RandomNormal (2018, 0.0f, 4.2f); |
|||
Assert.AreEqual(-0.9599, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(-0.5955, rn.NextDouble(), 0.0001); |
|||
} |
|||
Assert.AreEqual (firstValue * 4.2, rn.NextDouble (), epsilon); |
|||
Assert.AreEqual (secondValue * 4.2, rn.NextDouble (), epsilon); |
|||
} |
|||
[Test] |
|||
public void RandomNormalTestWithMeanStddev() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(2018, -3.2f, 2.2f); |
|||
[Test] |
|||
public void RandomNormalTestWithMeanStddev () |
|||
{ |
|||
float mean = -3.2f; |
|||
float stddev = 2.2f; |
|||
RandomNormal rn = new RandomNormal (2018, mean, stddev); |
|||
|
|||
Assert.AreEqual (firstValue * stddev + mean, rn.NextDouble (), epsilon); |
|||
Assert.AreEqual (secondValue * stddev + mean, rn.NextDouble (), epsilon); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestTensorInt () |
|||
{ |
|||
RandomNormal rn = new RandomNormal (1982); |
|||
Tensor t = new Tensor { |
|||
ValueType = Tensor.TensorType.Integer |
|||
}; |
|||
|
|||
Assert.Throws<NotImplementedException> (() => rn.FillTensor (t)); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestDataNull () |
|||
{ |
|||
RandomNormal rn = new RandomNormal (1982); |
|||
Tensor t = new Tensor { |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
|
|||
Assert.Throws<ArgumentNullException> (() => rn.FillTensor (t)); |
|||
} |
|||
|
|||
[Test] |
|||
public void RandomNormalTestDistribution () |
|||
{ |
|||
float mean = -3.2f; |
|||
float stddev = 2.2f; |
|||
RandomNormal rn = new RandomNormal (2018, mean, stddev); |
|||
|
|||
int numSamples = 100000; |
|||
// Adapted from https://www.johndcook.com/blog/standard_deviation/
|
|||
// Computes stddev and mean without losing precision
|
|||
double oldM = 0.0, newM = 0.0, oldS = 0.0, newS = 0.0; |
|||
Assert.AreEqual(-4.2266, rn.NextDouble(), 0.0001); |
|||
Assert.AreEqual(-4.0357, rn.NextDouble(), 0.0001); |
|||
} |
|||
for (int i = 0; i < numSamples; i++) { |
|||
double x = rn.NextDouble (); |
|||
if (i == 0) { |
|||
oldM = newM = x; |
|||
oldS = 0.0; |
|||
} else { |
|||
newM = oldM + (x - oldM) / i; |
|||
newS = oldS + (x - oldM) * (x - newM); |
|||
[Test] |
|||
public void RandomNormalTestTensorInt() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(1982); |
|||
Tensor t = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.Integer |
|||
}; |
|||
// set up for next iteration
|
|||
oldM = newM; |
|||
oldS = newS; |
|||
} |
|||
} |
|||
Assert.Throws<NotImplementedException>(() => rn.FillTensor(t)); |
|||
} |
|||
double sampleMean = newM; |
|||
double sampleVariance = newS / (numSamples - 1); |
|||
double sampleStddev = Math.Sqrt (sampleVariance); |
|||
[Test] |
|||
public void RandomNormalTestDataNull() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(1982); |
|||
Tensor t = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint |
|||
}; |
|||
// Note a larger epsilon here. We could get closer to the true values with more samples.
|
|||
Assert.AreEqual (mean, sampleMean, 0.01); |
|||
Assert.AreEqual (stddev, sampleStddev, 0.01); |
|||
Assert.Throws<ArgumentNullException>(() => rn.FillTensor(t)); |
|||
} |
|||
} |
|||
[Test] |
|||
public void RandomNormalTestTensor() |
|||
{ |
|||
RandomNormal rn = new RandomNormal(1982); |
|||
Tensor t = new Tensor |
|||
{ |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = Array.CreateInstance(typeof(float), new long[3] {3, 4, 2}) |
|||
}; |
|||
[Test] |
|||
public void RandomNormalTestTensor () |
|||
{ |
|||
RandomNormal rn = new RandomNormal (1982); |
|||
Tensor t = new Tensor { |
|||
ValueType = Tensor.TensorType.FloatingPoint, |
|||
Data = Array.CreateInstance (typeof (float), new long [3] { 3, 4, 2 }) |
|||
}; |
|||
rn.FillTensor(t); |
|||
rn.FillTensor (t); |
|||
float[] reference = new float[] |
|||
{ |
|||
-0.2139822f, |
|||
0.5051259f, |
|||
-0.5640336f, |
|||
-0.3357787f, |
|||
-0.2055894f, |
|||
-0.09432302f, |
|||
-0.01419199f, |
|||
0.53621f, |
|||
-0.5507085f, |
|||
-0.2651141f, |
|||
0.09315512f, |
|||
-0.04918706f, |
|||
-0.179625f, |
|||
0.2280539f, |
|||
0.1883962f, |
|||
0.4047216f, |
|||
0.1704049f, |
|||
0.5050544f, |
|||
-0.3365685f, |
|||
0.3542781f, |
|||
0.5951571f, |
|||
0.03460682f, |
|||
-0.5537263f, |
|||
-0.4378373f, |
|||
}; |
|||
float [] reference = new float [] |
|||
{ |
|||
-0.4315872f, |
|||
0.9561074f, |
|||
-1.130287f, |
|||
-0.7763879f, |
|||
-0.3027347f, |
|||
-0.1377991f, |
|||
-0.02921959f, |
|||
0.9520947f, |
|||
-1.11074f, |
|||
-0.5018106f, |
|||
0.1413168f, |
|||
-0.07491868f, |
|||
-0.2645015f, |
|||
0.3331701f, |
|||
0.3716498f, |
|||
1.088157f, |
|||
0.3414804f, |
|||
1.167787f, |
|||
-0.5105762f, |
|||
0.5396146f, |
|||
1.225356f, |
|||
0.06144788f, |
|||
-1.092338f, |
|||
-1.177194f, |
|||
}; |
|||
int i = 0; |
|||
foreach (float f in t.Data) |
|||
{ |
|||
Assert.AreEqual(f, reference[i], 0.0001); |
|||
++i; |
|||
} |
|||
int i = 0; |
|||
foreach (float f in t.Data) { |
|||
Assert.AreEqual (f, reference [i], epsilon); |
|||
++i; |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
撰写
预览
正在加载...
取消
保存
Reference in new issue