您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
54 行
1.4 KiB
54 行
1.4 KiB
using NUnit.Framework;
|
|
using Unity.MLAgents.Inference.Utils;
|
|
|
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Deterministic sampling checks for <see cref="Multinomial"/>.
    /// Every test builds a fresh sampler from the same fixed seed, draws from a
    /// cumulative distribution, and compares the drawn indices, in order, against
    /// a recorded sequence.
    /// </summary>
    /// <remarks>
    /// The expected indices for non-degenerate CDFs are tied to the seed (and,
    /// presumably, to the random generator used inside <see cref="Multinomial"/> —
    /// NOTE(review): confirm). Changing either will change the recorded sequences.
    /// </remarks>
    public class MultinomialTest
    {
        // Seed shared by every test so that draw sequences are reproducible.
        const int Seed = 2018;

        /// <summary>
        /// Creates a sampler seeded with <see cref="Seed"/> and asserts that
        /// successive <c>Sample(cdf)</c> calls return <paramref name="expectedIndices"/>
        /// in order.
        /// </summary>
        /// <param name="cdf">Cumulative distribution passed unchanged to every draw.</param>
        /// <param name="expectedIndices">Index expected from each successive draw.</param>
        static void AssertDrawSequence(float[] cdf, params int[] expectedIndices)
        {
            var sampler = new Multinomial(Seed);
            foreach (var expectedIndex in expectedIndices)
            {
                Assert.AreEqual(expectedIndex, sampler.Sample(cdf));
            }
        }

        /// <summary>A single-bucket CDF ending at 1 can only ever yield index 0.</summary>
        [Test]
        public void TestDim1()
        {
            AssertDrawSequence(new[] { 1f }, 0, 0, 0);
        }

        /// <summary>A single-bucket CDF whose total is not 1 must still always yield index 0.</summary>
        [Test]
        public void TestDim1Unscaled()
        {
            AssertDrawSequence(new[] { 0.1f }, 0, 0, 0);
        }

        /// <summary>Three buckets with a CDF ending at 1: checks the seeded draw sequence.</summary>
        [Test]
        public void TestDim3()
        {
            AssertDrawSequence(new[] { 0.1f, 0.3f, 1.0f }, 2, 2, 2, 1);
        }

        /// <summary>
        /// Same CDF as <see cref="TestDim3"/> but halved (total 0.5). It expects the
        /// identical draw sequence, i.e. sampling is expected to be invariant to the
        /// overall scale of the CDF.
        /// </summary>
        [Test]
        public void TestDim3Unscaled()
        {
            AssertDrawSequence(new[] { 0.05f, 0.15f, 0.5f }, 2, 2, 2, 1);
        }
    }
}
|