using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

namespace Unity.MLAgents.Actuators
{
    /// <summary>
    /// A class that manages the delegation of events, action buffers, and action mask for a list of IActuators.
    /// </summary>
    internal class ActuatorManager : IList<IActuator>
    {
        // IActuators managed by this object. Always a List<IActuator> (see the constructor),
        // which SortActuators relies on.
        readonly IList<IActuator> m_Actuators;

        // An implementation of IDiscreteActionMask that allows for writing to it based on an offset.
        ActuatorDiscreteActionMask m_DiscreteActionMask;

        /// <summary>
        /// Flag used to check if our IActuators are ready for execution.
        /// </summary>
        /// <seealso cref="ReadyActuatorsForExecution(IList{IActuator}, int, int, int)"/>
        bool m_ReadyForExecution;

        /// <summary>
        /// The sum of all of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
        /// </summary>
        internal int SumOfDiscreteBranchSizes { get; private set; }

        /// <summary>
        /// The number of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
        /// </summary>
        internal int NumDiscreteActions { get; private set; }

        /// <summary>
        /// The number of continuous actions for all of the <see cref="IActuator"/>s in this manager.
        /// </summary>
        internal int NumContinuousActions { get; private set; }

        /// <summary>
        /// Returns the total actions which is calculated by <see cref="NumContinuousActions"/> +
        /// <see cref="NumDiscreteActions"/>.
        /// </summary>
        public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions;

        /// <summary>
        /// Gets the <see cref="ActuatorDiscreteActionMask"/> managed by this object.
        /// Null until the manager has been readied for execution.
        /// </summary>
        public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask;

        /// <summary>
        /// Returns the previously stored continuous actions for the actuators in this list.
        /// Null until the buffers have been allocated by <see cref="ReadyActuatorsForExecution(IList{IActuator}, int, int, int)"/>.
        /// </summary>
        public float[] StoredContinuousActions { get; private set; }

        /// <summary>
        /// Returns the previously stored discrete actions for the actuators in this list.
        /// Null until the buffers have been allocated by <see cref="ReadyActuatorsForExecution(IList{IActuator}, int, int, int)"/>.
        /// </summary>
        public int[] StoredDiscreteActions { get; private set; }

        /// <summary>
        /// Create an ActuatorManager with a preset capacity.
        /// </summary>
        /// <param name="capacity">The capacity of the list to create.</param>
        public ActuatorManager(int capacity = 0)
        {
            m_Actuators = new List<IActuator>(capacity);
        }

        /// <summary>
        /// Readies this manager's own actuators for execution, using the bookkeeping totals that were
        /// accumulated as actuators were added. Does nothing if the buffers were already allocated.
        /// </summary>
        void ReadyActuatorsForExecution()
        {
            ReadyActuatorsForExecution(m_Actuators, NumContinuousActions, SumOfDiscreteBranchSizes, NumDiscreteActions);
        }

        /// <summary>
        /// This method sorts the managed <see cref="IActuator"/>s by name, validates that they have unique names
        /// and equivalent action space types if the `DEBUG` preprocessor macro is defined, and allocates the
        /// appropriate buffers to manage the actions for all of the <see cref="IActuator"/>s that may live on a
        /// particular object. Subsequent calls are no-ops.
        /// </summary>
        /// <param name="actuators">The list of actuators used to build the discrete action mask.
        /// NOTE(review): sorting and validation operate on this manager's own list, not on this argument;
        /// callers are presumably expected to pass that same list — confirm.</param>
        /// <param name="numContinuousActions">The total number of continuous actions for all of the actuators.</param>
        /// <param name="sumOfDiscreteBranches">The total sum of the discrete branches for all of the actuators in order
        /// to be able to allocate an <see cref="IDiscreteActionMask"/>.</param>
        /// <param name="numDiscreteBranches">The number of discrete branches for all of the actuators.</param>
        internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numContinuousActions, int sumOfDiscreteBranches, int numDiscreteBranches)
        {
            if (m_ReadyForExecution)
            {
                return;
            }

            // Sort the Actuators by name to ensure determinism. This must happen before validation:
            // the uniqueness check only compares neighbouring entries, so duplicates must already be adjacent.
            SortActuators();

#if DEBUG
            // Make sure the names are actually unique
            // Make sure all Actuators have the same SpaceType
            ValidateActuators();
#endif

            StoredContinuousActions = numContinuousActions == 0 ? Array.Empty<float>() : new float[numContinuousActions];
            StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty<int>() : new int[numDiscreteBranches];
            m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches);
            m_ReadyForExecution = true;
        }

        /// <summary>
        /// Updates the local action buffer with the action buffer passed in. If the buffer
        /// passed in is null (or empty), the local action buffer will be cleared.
        /// </summary>
        /// <param name="continuousActionBuffer">The action buffer which contains all of the
        /// continuous actions for the IActuators in this list.</param>
        /// <param name="discreteActionBuffer">The action buffer which contains all of the
        /// discrete actions for the IActuators in this list.</param>
        public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer)
        {
            ReadyActuatorsForExecution();
            UpdateActionArray(continuousActionBuffer, StoredContinuousActions);
            UpdateActionArray(discreteActionBuffer, StoredDiscreteActions);
        }

        /// <summary>
        /// Copies <paramref name="sourceActionBuffer"/> into <paramref name="destination"/>, or zeroes
        /// <paramref name="destination"/> when the source is null or empty.
        /// </summary>
        /// <param name="sourceActionBuffer">The incoming actions; expected to match the destination length.</param>
        /// <param name="destination">The stored action buffer to overwrite.</param>
        static void UpdateActionArray<T>(T[] sourceActionBuffer, T[] destination)
        {
            if (sourceActionBuffer == null || sourceActionBuffer.Length == 0)
            {
                Array.Clear(destination, 0, destination.Length);
            }
            else
            {
                Debug.Assert(sourceActionBuffer.Length == destination.Length,
                    $"sourceActionBuffer:{sourceActionBuffer.Length} is a different" +
                    $" size than destination: {destination.Length}.");

                Array.Copy(sourceActionBuffer, destination, destination.Length);
            }
        }

        /// <summary>
        /// This method will trigger the writing to the <see cref="IDiscreteActionMask"/> by all of the actuators
        /// managed by this object.
        /// </summary>
        public void WriteActionMask()
        {
            ReadyActuatorsForExecution();
            m_DiscreteActionMask.ResetMask();
            var offset = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var actuator = m_Actuators[i];
                // Each actuator writes its branches starting at the running branch offset.
                m_DiscreteActionMask.CurrentBranchOffset = offset;
                actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
                offset += actuator.ActionSpec.NumDiscreteActions;
            }
        }

        /// <summary>
        /// Iterates through all of the IActuators in this list and calls their
        /// <see cref="IActionReceiver.OnActionReceived"/> method on them with the appropriate
        /// <see cref="ActionSegment{T}"/>s depending on their <see cref="IActionReceiver.ActionSpec"/>.
        /// </summary>
        public void ExecuteActions()
        {
            ReadyActuatorsForExecution();
            var continuousStart = 0;
            var discreteStart = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var actuator = m_Actuators[i];
                var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
                var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;

                var continuousActions = ActionSegment<float>.Empty;
                if (numContinuousActions > 0)
                {
                    continuousActions = new ActionSegment<float>(StoredContinuousActions,
                        continuousStart,
                        numContinuousActions);
                }

                var discreteActions = ActionSegment<int>.Empty;
                if (numDiscreteActions > 0)
                {
                    discreteActions = new ActionSegment<int>(StoredDiscreteActions,
                        discreteStart,
                        numDiscreteActions);
                }

                actuator.OnActionReceived(new ActionBuffers(continuousActions, discreteActions));
                continuousStart += numContinuousActions;
                discreteStart += numDiscreteActions;
            }
        }

        /// <summary>
        /// Resets the <see cref="StoredContinuousActions"/> and <see cref="StoredDiscreteActions"/> buffers to be all
        /// zeros and calls <see cref="IActuator.ResetData"/> on each <see cref="IActuator"/> managed by this object.
        /// Does nothing if the manager has not been readied for execution yet.
        /// </summary>
        public void ResetData()
        {
            if (!m_ReadyForExecution)
            {
                return;
            }
            Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length);
            Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length);
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                m_Actuators[i].ResetData();
            }
        }

        /// <summary>
        /// Sorts the <see cref="IActuator"/>s according to their <see cref="IActuator.Name"/> value.
        /// The comparison does not depend on the machine's current culture, so the order (and therefore the
        /// layout of the action buffers) is the same on every machine.
        /// </summary>
        void SortActuators()
        {
            ((List<IActuator>)m_Actuators).Sort(CompareActuatorNames);
        }

        /// <summary>
        /// Culture-independent name comparison used by <see cref="SortActuators"/>.
        /// Ties under the invariant culture are broken ordinally, so only ordinally-equal names compare as equal.
        /// </summary>
        static int CompareActuatorNames(IActuator x, IActuator y)
        {
            var byInvariantCulture = string.Compare(x.Name, y.Name, StringComparison.InvariantCulture);
            return byInvariantCulture != 0 ? byInvariantCulture : string.CompareOrdinal(x.Name, y.Name);
        }

        /// <summary>
        /// Validates that the IActuators managed by this object have unique names and equivalent action space types.
        /// Each Actuator needs to have a unique name in order for this object to ensure that the storage of action
        /// buffers, and execution of Actuators remains deterministic across different sessions of running.
        /// Only neighbouring entries are compared, so the list must already be sorted by name.
        /// </summary>
        void ValidateActuators()
        {
            for (var i = 0; i < m_Actuators.Count - 1; i++)
            {
                Debug.Assert(
                    !m_Actuators[i].Name.Equals(m_Actuators[i + 1].Name),
                    "Actuator names must be unique.");
                var first = m_Actuators[i].ActionSpec;
                var second = m_Actuators[i + 1].ActionSpec;
                Debug.Assert((first.NumContinuousActions > 0) == (second.NumContinuousActions > 0),
                    "Actuators on the same Agent must have the same action SpaceType.");
            }
        }

        /// <summary>
        /// Helper method to update bookkeeping items around buffer management for actuators added to this object.
        /// </summary>
        /// <param name="actuatorItem">The IActuator to keep bookkeeping for. Null is ignored.</param>
        void AddToBufferSizes(IActuator actuatorItem)
        {
            if (actuatorItem == null)
            {
                return;
            }

            NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions;
            NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions;
            SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
        }

        /// <summary>
        /// Helper method to update bookkeeping items around buffer management for actuators removed from this object.
        /// </summary>
        /// <param name="actuatorItem">The IActuator to keep bookkeeping for. Null is ignored.</param>
        void SubtractFromBufferSize(IActuator actuatorItem)
        {
            if (actuatorItem == null)
            {
                return;
            }

            NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions;
            NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions;
            SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
        }

        /// <summary>
        /// Sets all of the bookkeeping items back to 0.
        /// </summary>
        void ClearBufferSizes()
        {
            NumContinuousActions = NumDiscreteActions = SumOfDiscreteBranchSizes = 0;
        }

        /*********************************************************************************
         * IList implementation that delegates to m_Actuators List.                      *
         *********************************************************************************/

        /// <inheritdoc/>
        public IEnumerator<IActuator> GetEnumerator()
        {
            return m_Actuators.GetEnumerator();
        }

        /// <inheritdoc/>
        IEnumerator IEnumerable.GetEnumerator()
        {
            return ((IEnumerable)m_Actuators).GetEnumerator();
        }

        /// <inheritdoc/>
        public void Add(IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot add to the ActuatorManager after its buffers have been initialized");
            m_Actuators.Add(item);
            AddToBufferSizes(item);
        }

        /// <inheritdoc/>
        public void Clear()
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot clear the ActuatorManager after its buffers have been initialized");
            m_Actuators.Clear();
            ClearBufferSizes();
        }

        /// <inheritdoc/>
        public bool Contains(IActuator item)
        {
            return m_Actuators.Contains(item);
        }

        /// <inheritdoc/>
        public void CopyTo(IActuator[] array, int arrayIndex)
        {
            m_Actuators.CopyTo(array, arrayIndex);
        }

        /// <inheritdoc/>
        public bool Remove(IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot remove from the ActuatorManager after its buffers have been initialized");
            if (m_Actuators.Remove(item))
            {
                SubtractFromBufferSize(item);
                return true;
            }
            return false;
        }

        /// <inheritdoc/>
        public int Count => m_Actuators.Count;

        /// <inheritdoc/>
        public bool IsReadOnly => m_Actuators.IsReadOnly;

        /// <inheritdoc/>
        public int IndexOf(IActuator item)
        {
            return m_Actuators.IndexOf(item);
        }

        /// <inheritdoc/>
        public void Insert(int index, IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot insert into the ActuatorManager after its buffers have been initialized");
            m_Actuators.Insert(index, item);
            AddToBufferSizes(item);
        }

        /// <inheritdoc/>
        public void RemoveAt(int index)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot remove from the ActuatorManager after its buffers have been initialized");
            // Read the actuator first so the bookkeeping can be updated with its sizes.
            var actuator = m_Actuators[index];
            SubtractFromBufferSize(actuator);
            m_Actuators.RemoveAt(index);
        }

        /// <inheritdoc/>
        public IActuator this[int index]
        {
            get => m_Actuators[index];
            set
            {
                Debug.Assert(m_ReadyForExecution == false,
                    "Cannot modify the ActuatorManager after its buffers have been initialized");
                var old = m_Actuators[index];
                SubtractFromBufferSize(old);
                m_Actuators[index] = value;
                AddToBufferSizes(value);
            }
        }
    }
}