using System;
using System.Collections.Generic;
using UnityEngine.Perception.Randomization.Samplers;

namespace UnityEngine.Perception.Randomization.Parameters
{
    /// <summary>
    /// Generates samples by choosing one option from a list of choices.
    /// </summary>
    /// <typeparam name="T">The sample type of the categorical parameter</typeparam>
    [Serializable]
    public abstract class CategoricalParameter<T> : CategoricalParameterBase
    {
        // When true, Sample() picks every category with equal likelihood and ignores the
        // per-category probabilities; when false, the probabilities weight the choice.
        [SerializeField] internal bool uniform = true;

        // Source of random values used to pick a category. Sample() assumes the values lie in
        // [0, 1) (the default is UniformSampler(0f, 1f)); out-of-range values are clamped.
        [SerializeReference] ISampler m_Sampler = new UniformSampler(0f, 1f);

        [SerializeField] List<T> m_Categories = new List<T>();

        // Cumulative distribution built from the inherited `probabilities` list (declared outside
        // this file): entry i is the normalized sum of probabilities[0..i].
        float[] m_NormalizedProbabilities;

        /// <summary>
        /// Returns an IEnumerable that iterates over each sampler field in this parameter
        /// </summary>
        internal override IEnumerable<ISampler> samplers
        {
            get { yield return m_Sampler; }
        }

        /// <summary>
        /// The sample type generated by this parameter
        /// </summary>
        public sealed override Type sampleType => typeof(T);

        /// <summary>
        /// Returns the number of stored categories
        /// </summary>
        /// <returns>The number of stored categories</returns>
        public int GetCategoryCount() => m_Categories.Count;

        /// <summary>
        /// Returns the category stored at the specified index
        /// </summary>
        /// <param name="index">The index of the category to lookup</param>
        /// <returns>The category stored at the specified index</returns>
        public T GetCategory(int index) => m_Categories[index];

        /// <summary>
        /// Returns the (unnormalized) probability value stored at the specified index
        /// </summary>
        /// <param name="index">The index of the probability value to lookup</param>
        /// <returns>The probability value stored at the specified index</returns>
        public float GetProbability(int index) => probabilities[index];

        /// <summary>
        /// Replaces this parameter's categorical options, giving every option a weight of 1.
        /// </summary>
        /// <param name="categoricalOptions">The categorical options to configure</param>
        /// <exception cref="ArgumentNullException"><paramref name="categoricalOptions"/> is null.</exception>
        /// <exception cref="ParameterValidationException">No options were supplied.</exception>
        public void SetOptions(IEnumerable<T> categoricalOptions)
        {
            if (categoricalOptions == null)
                throw new ArgumentNullException(nameof(categoricalOptions));

            // Materialize before clearing: the enumerable may lazily read from this parameter.
            var options = new List<T>(categoricalOptions);

            m_Categories.Clear();
            probabilities.Clear();
            foreach (var category in options)
                AddOption(category, 1f);
            NormalizeProbabilities();
        }

        /// <summary>
        /// Replaces this parameter's categorical options with weighted options and switches the
        /// parameter to non-uniform (weighted) sampling.
        /// </summary>
        /// <param name="categoricalOptions">The (category, probability) pairs to configure</param>
        /// <exception cref="ArgumentNullException"><paramref name="categoricalOptions"/> is null.</exception>
        /// <exception cref="ParameterValidationException">
        /// A probability is negative or non-finite, or the probabilities sum to zero.
        /// </exception>
        public void SetOptions(IEnumerable<(T, float)> categoricalOptions)
        {
            if (categoricalOptions == null)
                throw new ArgumentNullException(nameof(categoricalOptions));

            // Materialize before clearing: the enumerable may lazily read from this parameter.
            var options = new List<(T, float)>(categoricalOptions);

            m_Categories.Clear();
            probabilities.Clear();
            foreach (var (category, probability) in options)
                AddOption(category, probability);

            // Without this, the default uniform mode would silently ignore the supplied weights.
            uniform = false;
            NormalizeProbabilities();
        }

        void AddOption(T option, float probability)
        {
            m_Categories.Add(option);
            probabilities.Add(probability);
        }

        /// <summary>
        /// Returns a list of the potential categories this parameter can generate, paired with
        /// their stored (unnormalized) probabilities. A new list is built on every access.
        /// </summary>
        public IReadOnlyList<(T, float)> categories
        {
            get
            {
                var catOptions = new List<(T, float)>(m_Categories.Count);
                for (var i = 0; i < m_Categories.Count; i++)
                    catOptions.Add((m_Categories[i], probabilities[i]));
                return catOptions;
            }
        }

        /// <summary>
        /// Validates the categorical options and probabilities assigned to this parameter and, in
        /// weighted mode, rebuilds the cumulative distribution used by <see cref="Sample"/>.
        /// </summary>
        /// <exception cref="ParameterValidationException">
        /// There are no options, options are duplicated, or (weighted mode) the probabilities are
        /// invalid or do not match the number of options.
        /// </exception>
        public override void Validate()
        {
            base.Validate();

            // Check for a non-zero amount of specified categories
            if (m_Categories.Count == 0)
                throw new ParameterValidationException("No options added to categorical parameter");

            // Check for duplicate categories
            var uniqueCategories = new HashSet<T>();
            foreach (var option in m_Categories)
            {
                if (!uniqueCategories.Add(option))
                    throw new ParameterValidationException("Duplicate categories");
            }

            // Probabilities are only consulted in weighted mode.
            if (uniform)
                return;

            if (probabilities.Count != m_Categories.Count)
                throw new ParameterValidationException("Number of options must be equal to the number of probabilities");
            NormalizeProbabilities();
        }

        /// <summary>
        /// Builds <see cref="m_NormalizedProbabilities"/> as the cumulative distribution of the
        /// stored probabilities, normalized so the final reachable entry is exactly 1.
        /// </summary>
        /// <exception cref="ParameterValidationException">
        /// A probability is negative or non-finite, or the probabilities sum to zero.
        /// </exception>
        void NormalizeProbabilities()
        {
            var totalProbability = 0f;
            var lastPositiveIndex = -1;
            for (var i = 0; i < probabilities.Count; i++)
            {
                var probability = probabilities[i];
                if (float.IsNaN(probability) || float.IsInfinity(probability))
                    throw new ParameterValidationException($"Found non-finite probability at index {i}");
                if (probability < 0f)
                    throw new ParameterValidationException($"Found negative probability at index {i}");
                if (probability > 0f)
                    lastPositiveIndex = i;
                totalProbability += probability;
            }

            if (totalProbability <= 0f)
                throw new ParameterValidationException("Total probability must be greater than 0");

            var cumulative = new float[probabilities.Count];
            var sum = 0f;
            for (var i = 0; i < cumulative.Length; i++)
            {
                sum += probabilities[i] / totalProbability;
                cumulative[i] = sum;
            }

            // Float rounding can leave the running sum slightly below 1, so a random value just
            // under 1 would otherwise fall past the table. Pin the last positive-weight entry (and
            // the zero-weight tail after it) to exactly 1 so every value in [0, 1) maps to a
            // category with non-zero weight.
            for (var i = lastPositiveIndex; i < cumulative.Length; i++)
                cumulative[i] = 1f;

            m_NormalizedProbabilities = cumulative;
        }

        /// <summary>
        /// Returns the index of the first cumulative probability strictly greater than
        /// <paramref name="key"/> (an upper-bound search), clamped to the last valid index.
        /// </summary>
        /// <remarks>
        /// Using a strict "greater than" comparison means zero-weight categories, whose cumulative
        /// value equals their predecessor's, can never be selected.
        /// </remarks>
        int BinarySearch(float key)
        {
            var low = 0;
            var high = m_NormalizedProbabilities.Length;
            while (low < high)
            {
                var mid = low + (high - low) / 2;
                if (m_NormalizedProbabilities[mid] <= key)
                    low = mid + 1;
                else
                    high = mid;
            }

            // A key at or above the final cumulative value (e.g. a sampler returning exactly 1)
            // would otherwise index one past the end.
            return Math.Min(low, m_NormalizedProbabilities.Length - 1);
        }

        /// <summary>
        /// Generates a sample
        /// </summary>
        /// <returns>The generated sample</returns>
        public T Sample()
        {
            var randomValue = m_Sampler.Sample();

            if (uniform)
            {
                // Clamp so a random value of exactly 1 (or slightly outside [0, 1)) still maps to
                // a valid index.
                var index = (int)(randomValue * m_Categories.Count);
                index = Math.Max(0, Math.Min(index, m_Categories.Count - 1));
                return m_Categories[index];
            }

            // The cumulative table is only built by SetOptions/Validate; build it lazily if it is
            // missing (e.g. after deserialization) or out of sync with the stored probabilities.
            if (m_NormalizedProbabilities == null || m_NormalizedProbabilities.Length != probabilities.Count)
                NormalizeProbabilities();

            return m_Categories[BinarySearch(randomValue)];
        }

        /// <summary>
        /// Generates a generic sample
        /// </summary>
        /// <returns>The generated sample</returns>
        public override object GenericSample()
        {
            return Sample();
        }
    }
}