您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
91 行
2.9 KiB
91 行
2.9 KiB
using System;
|
|
using NUnit.Framework;
|
|
using Unity.MLAgents.Inference.Utils;
|
|
|
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Unit tests for <see cref="RandomNormal"/>. They check the first two draws for a fixed seed
    /// under several mean/stddev settings, and the empirical mean and standard deviation of a
    /// large sample.
    /// </summary>
    public class RandomNormalTest
    {
        // First two values drawn with seed 2018 when only the seed is passed to the constructor.
        // The other deterministic tests derive their expectations from these by scaling by the
        // stddev argument and shifting by the mean argument. The tests therefore expect the
        // seed-only constructor to behave like mean 0, stddev 1.
        const float k_FirstValue = -1.19580f;
        const float k_SecondValue = -0.97345f;

        // Tolerance for the exact-draw comparisons; the reference values above have 5 decimals.
        const double k_Epsilon = 0.0001;

        /// <summary>
        /// With only a seed, the first two draws match the reference values.
        /// </summary>
        [Test]
        public void RandomNormalTestTwoDouble()
        {
            var rn = new RandomNormal(2018);

            Assert.AreEqual(k_FirstValue, rn.NextDouble(), k_Epsilon);
            Assert.AreEqual(k_SecondValue, rn.NextDouble(), k_Epsilon);
        }

        /// <summary>
        /// A mean argument shifts every draw by that mean.
        /// </summary>
        [Test]
        public void RandomNormalTestWithMean()
        {
            var rn = new RandomNormal(2018, 5.0f);

            Assert.AreEqual(k_FirstValue + 5.0, rn.NextDouble(), k_Epsilon);
            Assert.AreEqual(k_SecondValue + 5.0, rn.NextDouble(), k_Epsilon);
        }

        /// <summary>
        /// A stddev argument (with zero mean) scales every draw by that stddev.
        /// </summary>
        [Test]
        public void RandomNormalTestWithStddev()
        {
            var rn = new RandomNormal(2018, 0.0f, 4.2f);

            Assert.AreEqual(k_FirstValue * 4.2, rn.NextDouble(), k_Epsilon);
            Assert.AreEqual(k_SecondValue * 4.2, rn.NextDouble(), k_Epsilon);
        }

        /// <summary>
        /// With both a mean and a stddev, each draw equals reference * stddev + mean.
        /// </summary>
        [Test]
        public void RandomNormalTestWithMeanStddev()
        {
            const float mean = -3.2f;
            const float stddev = 2.2f;
            var rn = new RandomNormal(2018, mean, stddev);

            Assert.AreEqual(k_FirstValue * stddev + mean, rn.NextDouble(), k_Epsilon);
            Assert.AreEqual(k_SecondValue * stddev + mean, rn.NextDouble(), k_Epsilon);
        }

        /// <summary>
        /// Over many draws, the sample mean and sample standard deviation approach the
        /// configured mean and stddev.
        /// </summary>
        [Test]
        public void RandomNormalTestDistribution()
        {
            const float mean = -3.2f;
            const float stddev = 2.2f;
            var rn = new RandomNormal(2018, mean, stddev);

            const int numSamples = 100000;

            // Welford's online algorithm, adapted from
            // https://www.johndcook.com/blog/standard_deviation/
            // It computes the mean and variance in one pass without the precision loss of the
            // naive sum / sum-of-squares formula.
            double runningMean = 0.0;
            double sumSquaredDiffs = 0.0;

            for (var i = 0; i < numSamples; i++)
            {
                var x = rn.NextDouble();

                // i is 0-based, so x is sample number i + 1. The divisor must be that count:
                // dividing by i would make the mean equal x at i == 1, discarding the first sample.
                var delta = x - runningMean;
                runningMean += delta / (i + 1);
                sumSquaredDiffs += delta * (x - runningMean);
            }

            var sampleMean = runningMean;
            // Bessel's correction: unbiased sample variance over all numSamples values.
            var sampleVariance = sumSquaredDiffs / (numSamples - 1);
            var sampleStddev = Math.Sqrt(sampleVariance);

            // Note a larger epsilon here. We could get closer to the true values with more samples.
            Assert.AreEqual(mean, sampleMean, 0.01);
            Assert.AreEqual(stddev, sampleStddev, 0.01);
        }
    }
}
|