namespace Unity.MLAgents.Inference.Utils
{
/// <summary>
/// Multinomial - Draws samples from a multinomial distribution given a (potentially unscaled)
/// cumulative mass function (CMF). The CMF need not "end" with a probability mass of 1.0;
/// for instance, [0.1, 0.2, 0.5] is a valid (unscaled) CMF. What is important is that it is
/// a cumulative function, not a probability function. In other words,
/// entry[i] = P(x \le i), NOT P(x = i)
/// (\le stands for "less than or equal to").
/// </summary>
internal class Multinomial
{
    readonly System.Random m_Random;

    /// <summary>
    /// Constructor.
    /// </summary>
    /// <param name="seed">
    /// Seed for the random number generator used in the sampling process.
    /// </param>
    public Multinomial(int seed)
    {
        m_Random = new System.Random(seed);
    }

    /// <summary>
    /// Samples from the Multinomial distribution defined by the first
    /// <paramref name="branchSize"/> entries of the provided cumulative mass function.
    /// </summary>
    /// <param name="cmf">
    /// Cumulative mass function, which may be unscaled. The entries in this array need
    /// to be monotonic (non-decreasing). If the CMF is scaled, then the entry at
    /// <c>branchSize - 1</c> will be 1.0.
    /// </param>
    /// <param name="branchSize">
    /// The number of possible branches, i.e. the effective size of the cmf array.
    /// Must be between 1 and <c>cmf.Length</c>.
    /// </param>
    /// <returns>
    /// A sampled index ranging from 0 to branchSize-1. For a monotonic, non-negative CMF
    /// with a positive total, index i is drawn with probability
    /// (cmf[i] - cmf[i - 1]) / cmf[branchSize - 1] (taking cmf[-1] as 0), so entries with
    /// zero mass are never returned. If the total mass is not positive, 0 is returned.
    /// </returns>
    /// <exception cref="System.ArgumentNullException"><paramref name="cmf"/> is null.</exception>
    /// <exception cref="System.ArgumentOutOfRangeException">
    /// <paramref name="branchSize"/> is less than 1 or greater than <c>cmf.Length</c>.
    /// </exception>
    public int Sample(float[] cmf, int branchSize)
    {
        if (cmf == null)
        {
            throw new System.ArgumentNullException(nameof(cmf));
        }
        if (branchSize <= 0 || branchSize > cmf.Length)
        {
            throw new System.ArgumentOutOfRangeException(nameof(branchSize), branchSize,
                "branchSize must be between 1 and the length of cmf.");
        }

        var lastIndex = branchSize - 1;
        var total = cmf[lastIndex];

        // A CMF with no positive mass (or a NaN total) has nothing to sample from;
        // fall back to the first index. The original comparison loop also returned
        // 0 here for all-zero and NaN totals.
        if (!(total > 0f))
        {
            return 0;
        }

        // Draw p uniformly from [0, total). Keeping p in double precision keeps it strictly
        // below total; casting NextDouble() to float could round it up to 1.0 and make
        // p == total.
        var p = m_Random.NextDouble() * total;

        // Class i owns the half-open interval [cmf[i - 1], cmf[i]). Comparing with <=
        // means a zero-mass class (cmf[i - 1] == cmf[i], including cmf[0] == 0) can never
        // be selected. The index bound keeps the result within [0, branchSize - 1] even
        // when the CMF is not monotonic.
        var cls = 0;
        while (cls < lastIndex && cmf[cls] <= p)
        {
            ++cls;
        }
        return cls;
    }

    /// <summary>
    /// Samples from the Multinomial distribution defined by the whole provided
    /// cumulative mass function.
    /// </summary>
    /// <param name="cmf">
    /// Cumulative mass function, which may be unscaled; see
    /// <see cref="Sample(float[], int)"/> for the requirements on its entries.
    /// </param>
    /// <returns>A sampled index from the CMF ranging from 0 to cmf.Length-1.</returns>
    /// <exception cref="System.ArgumentNullException"><paramref name="cmf"/> is null.</exception>
    /// <exception cref="System.ArgumentOutOfRangeException"><paramref name="cmf"/> is empty.</exception>
    public int Sample(float[] cmf)
    {
        if (cmf == null)
        {
            throw new System.ArgumentNullException(nameof(cmf));
        }
        return Sample(cmf, cmf.Length);
    }
}
}