using System;
using System.Collections.Generic;
using UnityEngine.Perception.Randomization.Samplers;

namespace UnityEngine.Perception.Randomization.Parameters
{
    /// <summary>
    /// Generates samples by choosing one option from a list of choices
    /// </summary>
    /// <typeparam name="T">The sample type of the categorical parameter</typeparam>
    [Serializable]
    public abstract class CategoricalParameter<T> : CategoricalParameterBase
    {
        /// <summary>
        /// When true, every category is equally likely: <see cref="Sample"/> ignores the stored
        /// probabilities and <see cref="Validate"/> skips probability checks.
        /// </summary>
        [SerializeField] internal bool uniform;

        // Source of random values mapped onto a category index (defaults to uniform over [0, 1]).
        [SerializeReference] ISampler m_Sampler = new UniformSampler(0f, 1f);

        [SerializeField] List<T> m_Categories = new List<T>();

        // Cumulative distribution derived from `probabilities`. This field is not serialized, so it is
        // null after deserialization; it is also invalidated whenever the option list changes and is
        // rebuilt lazily by Sample() or eagerly by Validate()/NormalizeProbabilities().
        float[] m_NormalizedProbabilities;

        /// <summary>
        /// Returns a list containing the samplers attached to this parameter
        /// </summary>
        public override ISampler[] samplers => new[] { m_Sampler };

        /// <summary>
        /// The sample type generated by this parameter
        /// </summary>
        public sealed override Type sampleType => typeof(T);

        /// <summary>
        /// Returns the category stored at the specified index
        /// </summary>
        /// <param name="index">The index of the category to lookup</param>
        /// <returns>The category stored at the specified index</returns>
        public T GetCategory(int index) => m_Categories[index];

        /// <summary>
        /// Returns the probability value stored at the specified index
        /// </summary>
        /// <param name="index">The index of the probability value to lookup</param>
        /// <returns>The probability value stored at the specified index</returns>
        public float GetProbability(int index) => probabilities[index];

        /// <summary>
        /// Constructs a new categorical parameter
        /// </summary>
        protected CategoricalParameter() { }

        /// <summary>
        /// Create a new categorical parameter from a list of categories with uniform probabilities
        /// </summary>
        /// <param name="categoricalOptions">List of categories</param>
        /// <exception cref="ArgumentNullException"><paramref name="categoricalOptions"/> is null</exception>
        /// <exception cref="ArgumentException"><paramref name="categoricalOptions"/> is empty</exception>
        protected CategoricalParameter(IEnumerable<T> categoricalOptions)
        {
            if (categoricalOptions == null)
                throw new ArgumentNullException(nameof(categoricalOptions));

            // Materialize once: validate the caller's input (not this instance's still-empty list)
            // and avoid enumerating a lazy sequence twice.
            var options = new List<T>(categoricalOptions);
            if (options.Count == 0)
                throw new ArgumentException("List of options is empty");

            uniform = true;
            foreach (var option in options)
                AddOption(option, 1f);
        }

        /// <summary>
        /// Creates a new categorical parameter from a list of categories and their associated probabilities
        /// </summary>
        /// <param name="categoricalOptions">List of categories and their associated probabilities</param>
        /// <exception cref="ArgumentNullException"><paramref name="categoricalOptions"/> is null</exception>
        /// <exception cref="ArgumentException"><paramref name="categoricalOptions"/> is empty</exception>
        protected CategoricalParameter(IEnumerable<(T, float)> categoricalOptions)
        {
            if (categoricalOptions == null)
                throw new ArgumentNullException(nameof(categoricalOptions));

            var options = new List<(T, float)>(categoricalOptions);
            if (options.Count == 0)
                throw new ArgumentException("List of options is empty");

            foreach (var (category, probability) in options)
                AddOption(category, probability);
            NormalizeProbabilities();
        }

        /// <summary>
        /// Appends a default-valued category with probability 0.
        /// </summary>
        internal override void AddOption()
        {
            m_Categories.Add(default);
            probabilities.Add(0f);
            m_NormalizedProbabilities = null;
        }

        /// <summary>
        /// Appends the given category with the given (unnormalized) probability.
        /// </summary>
        internal void AddOption(T option, float probability)
        {
            m_Categories.Add(option);
            probabilities.Add(probability);
            m_NormalizedProbabilities = null;
        }

        /// <summary>
        /// Removes the category and probability stored at <paramref name="index"/>.
        /// </summary>
        internal override void RemoveOption(int index)
        {
            m_Categories.RemoveAt(index);
            probabilities.RemoveAt(index);
            m_NormalizedProbabilities = null;
        }

        /// <summary>
        /// Removes every category and probability.
        /// </summary>
        internal override void ClearOptions()
        {
            m_Categories.Clear();
            probabilities.Clear();
            m_NormalizedProbabilities = null;
        }

        /// <summary>
        /// Returns a list of the potential categories this parameter can generate,
        /// each paired with its stored (unnormalized) probability. The list is a snapshot copy.
        /// </summary>
        public IReadOnlyList<(T, float)> categories
        {
            get
            {
                var catOptions = new List<(T, float)>(m_Categories.Count);
                // Bound by the source list; the new list's Count is 0 (only its capacity is preset).
                for (var i = 0; i < m_Categories.Count; i++)
                    catOptions.Add((m_Categories[i], probabilities[i]));
                return catOptions;
            }
        }

        /// <summary>
        /// Validates the categorical probabilities assigned to this parameter
        /// </summary>
        /// <exception cref="ParameterValidationException">
        /// The number of probabilities differs from the number of categories, a probability is
        /// negative, or the probabilities do not sum to a positive value (non-uniform mode only).
        /// </exception>
        internal override void Validate()
        {
            base.Validate();
            if (!uniform)
            {
                if (probabilities.Count != m_Categories.Count)
                    throw new ParameterValidationException("Number of options must be equal to the number of probabilities");
                NormalizeProbabilities();
            }
        }

        /// <summary>
        /// Rebuilds the cumulative distribution: entry i holds the sum of probabilities[0..i]
        /// divided by their total.
        /// </summary>
        /// <exception cref="ParameterValidationException">
        /// A probability is negative, or the total is not greater than 0.
        /// </exception>
        internal void NormalizeProbabilities()
        {
            var totalProbability = 0f;
            for (var i = 0; i < probabilities.Count; i++)
            {
                var probability = probabilities[i];
                if (probability < 0f)
                    throw new ParameterValidationException($"Found negative probability at index {i}");
                totalProbability += probability;
            }

            // Written as !(x > 0) so that a NaN total is rejected as well.
            if (!(totalProbability > 0f))
                throw new ParameterValidationException("Total probability must be greater than 0");

            var sum = 0f;
            m_NormalizedProbabilities = new float[probabilities.Count];
            for (var i = 0; i < probabilities.Count; i++)
            {
                sum += probabilities[i] / totalProbability;
                m_NormalizedProbabilities[i] = sum;
            }
        }

        /// <summary>
        /// Returns the index of the first cumulative probability strictly greater than
        /// <paramref name="key"/>, clamped to the last index. Uses a strict comparison, so categories
        /// with zero probability (no increase in the cumulative sum) are never selected for keys
        /// inside [0, 1). Clamping guards against key == 1 and float rounding that leaves the last
        /// cumulative value slightly below 1.
        /// </summary>
        int BinarySearch(float key)
        {
            var low = 0;
            var high = m_NormalizedProbabilities.Length - 1;
            while (low < high)
            {
                var mid = low + (high - low) / 2;
                if (key < m_NormalizedProbabilities[mid])
                    high = mid;
                else
                    low = mid + 1;
            }
            return low;
        }

        /// <summary>
        /// Generates a sample
        /// </summary>
        /// <returns>The generated sample</returns>
        /// <exception cref="ParameterValidationException">
        /// Non-uniform mode needs to (re)build its distribution and the stored probabilities are invalid.
        /// </exception>
        public T Sample()
        {
            var randomValue = m_Sampler.Sample();

            if (uniform)
            {
                // A sampler that can return its upper bound (1) would otherwise map to index Count.
                var index = (int)(randomValue * m_Categories.Count);
                return m_Categories[Math.Min(index, m_Categories.Count - 1)];
            }

            // The cumulative distribution is not serialized and is invalidated by option edits.
            if (m_NormalizedProbabilities == null)
                NormalizeProbabilities();

            return m_Categories[BinarySearch(randomValue)];
        }

        /// <summary>
        /// Samples a value and applies it to this parameter's target, if one is set.
        /// </summary>
        /// <param name="seedOffset">Not used by this implementation.</param>
        internal sealed override void ApplyToTarget(int seedOffset)
        {
            if (!hasTarget)
                return;
            target.ApplyValueToTarget(Sample());
        }
    }
}