using System;
using System.Collections.Generic;
using UnityEngine.Experimental.Perception.Randomization.Samplers;
namespace UnityEngine.Experimental.Perception.Randomization.Parameters
/// <summary>
/// Generates samples by choosing one option from a list of choices
/// </summary>
/// <typeparam name="T">The sample type of the categorical parameter</typeparam>
public abstract class CategoricalParameter : CategoricalParameterBase
// When true, Sample() maps the random value straight onto a category index and the
// per-category probabilities are not consulted.
[SerializeField] internal bool uniform = true;

// Source of the random value consumed by Sample(); constructed by default as UniformSampler(0f, 1f).
[SerializeReference] ISampler m_Sampler = new UniformSampler(0f, 1f);

// The options this parameter chooses between. The element type is T, matching what
// GetCategory() and Sample() return.
[SerializeField] List<T> m_Categories = new List<T>();

// Cumulative, normalized probabilities: written by NormalizeProbabilities() and searched by
// BinarySearch(). Not serialized, so it is null until NormalizeProbabilities() has run.
float[] m_NormalizedProbabilities;
/// <summary>
/// Returns an IEnumerable that iterates over each sampler field in this parameter
/// </summary>
/// <remarks>
/// This parameter owns a single sampler, so the sequence yields exactly one element.
/// NOTE(review): the element type is restored as ISampler (the type of m_Sampler); confirm it
/// matches the declaration being overridden in CategoricalParameterBase.
/// </remarks>
internal override IEnumerable<ISampler> samplers
{
    get { yield return m_Sampler; }
}
/// <summary>
/// The sample type generated by this parameter, i.e. the generic argument T
/// </summary>
public sealed override Type sampleType
{
    get { return typeof(T); }
}
/// <summary>
/// Reports how many categories are currently stored in this parameter
/// </summary>
/// <returns>The number of stored categories</returns>
public int GetCategoryCount()
{
    return m_Categories.Count;
}
/// <summary>
/// Looks up the category stored at the given position
/// </summary>
/// <param name="index">The index of the category to lookup</param>
/// <returns>The category stored at the specified index</returns>
public T GetCategory(int index)
{
    return m_Categories[index];
}
/// <summary>
/// Looks up the probability value stored at the given position. The value is read from the
/// probabilities collection as stored, not from the cumulative table used for sampling.
/// </summary>
/// <param name="index">The index of the probability value to lookup</param>
/// <returns>The probability value stored at the specified index</returns>
public float GetProbability(int index)
{
    return probabilities[index];
}
/// <summary>
/// Replaces this parameter's list of categorical options, giving every option a weight of 1
/// </summary>
/// <param name="categoricalOptions">The categorical options to configure</param>
/// <exception cref="ParameterValidationException">
/// Propagated from NormalizeProbabilities if the stored weights cannot be normalized
/// </exception>
public void SetOptions(IEnumerable<T> categoricalOptions)
{
    // Start from a clean slate so that calling SetOptions again replaces the previous options
    // instead of appending to them (appending would also create duplicates that Validate rejects).
    m_Categories.Clear();
    probabilities.Clear();

    foreach (var category in categoricalOptions)
    {
        AddOption(category, 1f);
    }

    // Keep the cumulative table used by non-uniform sampling in step with the new options.
    // Skipped when nothing was stored so an empty input does not throw here; Validate() is the
    // place that reports a parameter without options.
    if (probabilities.Count > 0)
    {
        NormalizeProbabilities();
    }
}
/// <summary>
/// Replaces this parameter's list of categorical options, pairing each option with its own weight
/// </summary>
/// <param name="categoricalOptions">The categorical options to configure, as (option, probability) pairs</param>
/// <exception cref="ParameterValidationException">
/// Propagated from NormalizeProbabilities when a supplied probability is negative or the
/// probabilities do not add up to more than zero
/// </exception>
public void SetOptions(IEnumerable<(T, float)> categoricalOptions)
{
    // Start from a clean slate so that calling SetOptions again replaces the previous options
    // instead of appending to them (appending would also create duplicates that Validate rejects).
    m_Categories.Clear();
    probabilities.Clear();

    foreach (var (category, probability) in categoricalOptions)
    {
        AddOption(category, probability);
    }

    // Keep the cumulative table used by non-uniform sampling in step with the new options.
    // Skipped when nothing was stored so an empty input does not throw here; Validate() is the
    // place that reports a parameter without options.
    if (probabilities.Count > 0)
    {
        NormalizeProbabilities();
    }
}
// Called once per incoming option by both SetOptions overloads; the unweighted overload passes
// a probability of 1f for every option.
// NOTE(review): the body of this method is not visible in this excerpt. Presumably it appends
// `option` to m_Categories and `probability` to `probabilities` - confirm against the full file.
void AddOption(T option, float probability)
/// <summary>
/// Returns a list of the potential categories this parameter can generate. Every entry pairs a
/// category with the probability value stored at the same index, and a new list is built on
/// each access.
/// </summary>
public IReadOnlyList<(T, float)> categories
{
    get
    {
        var total = m_Categories.Count;
        var pairs = new List<(T, float)>(total);
        for (var index = 0; index < total; index++)
        {
            var entry = (m_Categories[index], probabilities[index]);
            pairs.Add(entry);
        }

        return pairs;
    }
}
/// <summary>
/// Validates the categorical probabilities assigned to this parameter
/// </summary>
/// <exception cref="ParameterValidationException">
/// Thrown when no categories are stored, when a category occurs more than once, or (in
/// non-uniform mode) when the number of probabilities differs from the number of categories
/// or the probabilities cannot be normalized
/// </exception>
public override void Validate()
{
    // Check for a non-zero amount of specified categories
    if (m_Categories.Count == 0)
    {
        throw new ParameterValidationException("No options added to categorical parameter");
    }

    // Check for duplicate categories. HashSet.Add returns false when the item is already in the
    // set, so every category is recorded and compared in a single step; a Contains-only check
    // against a set that is never filled could not detect anything.
    var uniqueCategories = new HashSet<T>();
    foreach (var option in m_Categories)
    {
        if (!uniqueCategories.Add(option))
        {
            throw new ParameterValidationException("Duplicate categories");
        }
    }

    // Probabilities are only consulted in non-uniform mode
    if (uniform)
    {
        return;
    }

    // Check if the number of specified probabilities is different from the number of listed categories
    if (probabilities.Count != m_Categories.Count)
    {
        throw new ParameterValidationException("Number of options must be equal to the number of probabilities");
    }

    // The cumulative table is not serialized, so build it here; this also rejects negative
    // probabilities and a non-positive total before any sample is drawn.
    NormalizeProbabilities();
}
/// <summary>
/// Rebuilds m_NormalizedProbabilities as the running sum of the stored probabilities, each
/// divided by their total, so that entry i holds the cumulative share of options 0..i
/// </summary>
/// <exception cref="ParameterValidationException">
/// Thrown when a stored probability is negative or when the probabilities do not add up to
/// more than zero
/// </exception>
void NormalizeProbabilities()
{
    var count = probabilities.Count;

    // First pass: reject negative weights and accumulate the total used for scaling
    var weightTotal = 0f;
    for (var i = 0; i < count; i++)
    {
        var weight = probabilities[i];
        if (weight < 0f)
        {
            throw new ParameterValidationException($"Found negative probability at index {i}");
        }

        weightTotal += weight;
    }

    if (weightTotal <= 0f)
    {
        throw new ParameterValidationException("Total probability must be greater than 0");
    }

    // Second pass: store the cumulative normalized weights
    var cumulative = new float[count];
    var running = 0f;
    for (var slot = 0; slot < count; slot++)
    {
        running += probabilities[slot] / weightTotal;
        cumulative[slot] = running;
    }

    m_NormalizedProbabilities = cumulative;
}
/// <summary>
/// Finds the category whose slice of the cumulative probability table contains the key.
/// Category i covers the half-open range [cumulative[i - 1], cumulative[i]), so the result is
/// the index of the first cumulative entry that is strictly greater than the key.
/// </summary>
/// <param name="key">The random value to locate in the cumulative table</param>
/// <returns>
/// An index into m_NormalizedProbabilities that is always within bounds for a non-empty table
/// (0 for an empty table)
/// </returns>
int BinarySearch(float key)
{
    var low = 0;

    // The last index doubles as the fallback answer, so the search itself only has to look at
    // the entries before it and the result can never run past the end of the table.
    var high = m_NormalizedProbabilities.Length - 1;

    while (low < high)
    {
        var mid = low + (high - low) / 2;
        if (key < m_NormalizedProbabilities[mid])
        {
            high = mid;
        }
        else
        {
            // key >= entry: the match lies strictly to the right. Treating "equal" like
            // "greater" skips over runs of repeated entries (zero-probability categories).
            low = mid + 1;
        }
    }

    // When the key is at or above the final cumulative value (float rounding can leave the last
    // entry slightly below 1, or the sampler may return its upper bound) the fallback index may
    // belong to a zero-probability category. Step back to the last category with a non-zero
    // slice. For an ordinary match this loop never runs.
    while (low > 0 && m_NormalizedProbabilities[low] <= m_NormalizedProbabilities[low - 1])
    {
        low--;
    }

    return low;
}
/// <summary>
/// Generates a sample by drawing one random value from the sampler and mapping it to a category:
/// by equal-width slices in uniform mode, otherwise by searching the cumulative probability table
/// </summary>
/// <returns>The generated sample</returns>
public T Sample()
{
    var randomValue = m_Sampler.Sample();

    if (!uniform)
    {
        return m_Categories[BinarySearch(randomValue)];
    }

    // Scale the random value onto the category indices. A value of exactly 1.0 would scale to
    // Count, one past the last valid index, so fold that edge onto the last category.
    // An empty list still throws, as the resulting index is -1.
    var index = (int)(randomValue * m_Categories.Count);
    if (index >= m_Categories.Count)
    {
        index = m_Categories.Count - 1;
    }

    return m_Categories[index];
}
/// <summary>
/// Generates a generic sample by forwarding to Sample() and returning the result as an object
/// </summary>
/// <returns>The generated sample</returns>
public override object GenericSample() => Sample();