using System;
using Assert = UnityEngine.Assertions.Assert;
using UnityEngine;

namespace MLAgents.InferenceBrain.Utils
{
    /// <summary>
    /// Multinomial - Draws samples from a multinomial distribution in log space
    /// Reference: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/multinomial_op.cc
    /// </summary>
    public class Multinomial
    {
        private readonly System.Random m_random;

        /// <summary>
        /// Creates a sampler whose random stream is fully determined by <paramref name="seed"/>.
        /// </summary>
        /// <param name="seed">Seed passed to the underlying <see cref="System.Random"/>.</param>
        public Multinomial(int seed)
        {
            m_random = new System.Random(seed);
        }

        /// <summary>
        /// Draw samples from a multinomial distribution based on log-probabilities specified in tensor src. The samples
        /// (class indices, stored as floats) will be saved in the dst tensor.
        /// </summary>
        /// <param name="src">2-D tensor with shape batch_size x num_classes holding unnormalized log-probabilities</param>
        /// <param name="dst">Allocated tensor with size batch_size x num_samples that receives the sampled class indices</param>
        /// <exception cref="NotImplementedException">Multinomial doesn't support integer tensors</exception>
        /// <exception cref="ArgumentException">Issue with tensor shape or type</exception>
        /// <exception cref="ArgumentNullException">At least one of the tensors is not allocated</exception>
        public void Eval(Tensor src, Tensor dst)
        {
            if (src.DataType != typeof(float))
            {
                throw new NotImplementedException("Multinomial does not support integer tensors yet!");
            }

            if (src.ValueType != dst.ValueType)
            {
                throw new ArgumentException("Source and destination tensors have different types!");
            }

            if (src.Data == null || dst.Data == null)
            {
                // Name the unallocated tensor so the failure is easier to trace.
                throw new ArgumentNullException(src.Data == null ? "src" : "dst");
            }

            float[,] input_data = src.Data as float[,];
            if (input_data == null)
            {
                throw new ArgumentException("Input data is not of the correct shape! Required batch x logits");
            }

            float[,] output_data = dst.Data as float[,];
            if (output_data == null)
            {
                throw new ArgumentException("Output data is not of the correct shape! Required batch x samples");
            }

            int batchSize = input_data.GetLength(0);
            int numClasses = input_data.GetLength(1);
            int numSamples = output_data.GetLength(1);

            if (batchSize != output_data.GetLength(0))
            {
                throw new ArgumentException("Batch size for input and output data is different!");
            }

            // With no classes there is nothing to draw from; report it as a shape problem instead of
            // letting the CDF lookup fail with an IndexOutOfRangeException.
            if (numClasses == 0 && batchSize > 0 && numSamples > 0)
            {
                throw new ArgumentException("Input data must contain at least one class to sample from!");
            }

            // Reused for every row; double precision limits rounding in the running sum and the draw.
            double[] cdf = new double[numClasses];
            for (int batch = 0; batch < batchSize; ++batch)
            {
                double totalMass = ComputeUnnormalizedCdf(input_data, batch, cdf);

                for (int sample = 0; sample < numSamples; ++sample)
                {
                    // NextDouble() lies in [0, 1), so the target lies in [0, totalMass) for a finite row.
                    double target = m_random.NextDouble() * totalMass;
                    output_data[batch, sample] = UpperBound(cdf, target);
                }
            }
        }

        /// <summary>
        /// Fills <paramref name="cdf"/> with the running sum of exp(logit - maxLogit) over the classes of
        /// row <paramref name="row"/> and returns the final (total) value.
        /// </summary>
        /// <param name="logits">batch x classes array of log-probabilities.</param>
        /// <param name="row">Index of the batch row to process.</param>
        /// <param name="cdf">Destination buffer; must hold at least as many entries as there are classes.</param>
        /// <returns>The unnormalized total probability mass of the row.</returns>
        private static double ComputeUnnormalizedCdf(float[,] logits, int row, double[] cdf)
        {
            int numClasses = logits.GetLength(1);

            // Find the class maximum; subtracting it before exponentiating keeps Exp from overflowing.
            // The comparison mirrors Mathf.Max(value, max), so a NaN logit never becomes the maximum.
            float maxLogit = float.NegativeInfinity;
            for (int cls = 0; cls < numClasses; ++cls)
            {
                float logit = logits[row, cls];
                if (logit > maxLogit)
                {
                    maxLogit = logit;
                }
            }

            // NOTE(review): non-finite logits are not filtered here (the referenced TensorFlow kernel appears
            // to skip them - verify). A NaN or +Infinity logit, or a row that is entirely -Infinity, turns the
            // sums into NaN, and every sample for that row then resolves to class 0.
            double total = 0.0;
            for (int cls = 0; cls < numClasses; ++cls)
            {
                total += Math.Exp((double)logits[row, cls] - maxLogit);
                cdf[cls] = total;
            }
            return total;
        }

        /// <summary>
        /// Returns the index of the first entry of the non-decreasing <paramref name="cdf"/> that is strictly
        /// greater than <paramref name="target"/> (std::upper_bound semantics), clamped to the last index.
        /// </summary>
        /// <remarks>
        /// "Strictly greater" means a class whose CDF entry equals its predecessor's (zero probability mass,
        /// e.g. an exponent that underflowed to 0) is never chosen for a finite row, even when target is
        /// exactly 0. The clamp keeps rounding or a NaN target from running the search past the end.
        /// </remarks>
        private static int UpperBound(double[] cdf, double target)
        {
            int last = cdf.Length - 1;
            int cls = 0;
            while (cls < last && cdf[cls] <= target)
            {
                ++cls;
            }
            return cls;
        }
    }
}