using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

namespace Unity.MLAgents.Actuators
{
    /// <summary>
    /// A class that manages the delegation of events, action buffers, and action mask for a list of IActuators.
    /// </summary>
    internal class ActuatorManager : IList<IActuator>
    {
        // IActuators managed by this object. Typed as List so it can be sorted in place without a downcast.
        readonly List<IActuator> m_Actuators;

        // An implementation of IDiscreteActionMask that allows for writing to it based on an offset.
        ActuatorDiscreteActionMask m_DiscreteActionMask;

        // Concatenation of the ActionSpecs of every managed actuator; built when the buffers are allocated.
        ActionSpec m_CombinedActionSpec;

        /// <summary>
        /// Flag used to check if our IActuators are ready for execution. Once it is set, the buffers have been
        /// allocated and the list of actuators must no longer be modified.
        /// </summary>
        bool m_ReadyForExecution;

        /// <summary>
        /// The sum of all of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
        /// </summary>
        internal int SumOfDiscreteBranchSizes { get; private set; }

        /// <summary>
        /// The number of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
        /// </summary>
        internal int NumDiscreteActions { get; private set; }

        /// <summary>
        /// The number of continuous actions for all of the <see cref="IActuator"/>s in this manager.
        /// </summary>
        internal int NumContinuousActions { get; private set; }

        /// <summary>
        /// Returns the total actions which is calculated by <see cref="NumContinuousActions"/> +
        /// <see cref="NumDiscreteActions"/>.
        /// </summary>
        public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions;

        /// <summary>
        /// Gets the <see cref="ActuatorDiscreteActionMask"/> managed by this object.
        /// It is only assigned once the buffers have been allocated.
        /// </summary>
        public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask;

        /// <summary>
        /// The currently stored <see cref="ActionBuffers"/> object for the <see cref="IActuator"/>s managed
        /// by this class.
        /// </summary>
        public ActionBuffers StoredActions { get; private set; }

        /// <summary>
        /// Create an ActuatorList with a preset capacity.
        /// </summary>
        /// <param name="capacity">The capacity of the list to create.</param>
        public ActuatorManager(int capacity = 0)
        {
            m_Actuators = new List<IActuator>(capacity);
        }

        /// <summary>
        /// Readies the managed actuators for execution using the bookkeeping totals tracked by this object.
        /// </summary>
        void ReadyActuatorsForExecution()
        {
            ReadyActuatorsForExecution(m_Actuators, NumContinuousActions, SumOfDiscreteBranchSizes,
                NumDiscreteActions);
        }

        /// <summary>
        /// This method validates that all <see cref="IActuator"/>s have unique names
        /// if the `DEBUG` preprocessor macro is defined, sorts the managed actuators by name, and allocates
        /// the appropriate buffers to manage the actions for all of the actuators.
        /// Does nothing if the buffers have already been allocated.
        /// </summary>
        /// <param name="actuators">The list of actuators to allocate buffers for.</param>
        /// <param name="numContinuousActions">The total number of continuous actions for all of the actuators.</param>
        /// <param name="sumOfDiscreteBranches">The total sum of the discrete branches for all of the actuators,
        /// passed on to the <see cref="ActuatorDiscreteActionMask"/> that is created.</param>
        /// <param name="numDiscreteBranches">The number of discrete branches for all of the actuators.</param>
        internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numContinuousActions,
            int sumOfDiscreteBranches, int numDiscreteBranches)
        {
            if (m_ReadyForExecution)
            {
                return;
            }
#if DEBUG
            // Make sure the names are actually unique
            ValidateActuators();
#endif

            // Sort the Actuators by name to ensure determinism
            SortActuators();

            var continuousActions = numContinuousActions == 0
                ? ActionSegment<float>.Empty
                : new ActionSegment<float>(new float[numContinuousActions]);
            var discreteActions = numDiscreteBranches == 0
                ? ActionSegment<int>.Empty
                : new ActionSegment<int>(new int[numDiscreteBranches]);

            StoredActions = new ActionBuffers(continuousActions, discreteActions);
            m_CombinedActionSpec = CombineActionSpecs(actuators);
            m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches,
                numDiscreteBranches, m_CombinedActionSpec.BranchSizes);
            m_ReadyForExecution = true;
        }

        /// <summary>
        /// Builds a single <see cref="ActionSpec"/> whose continuous action count is the sum over all
        /// actuators and whose branch sizes are the actuators' branch sizes concatenated in list order.
        /// </summary>
        /// <param name="actuators">The actuators whose ActionSpecs should be combined.</param>
        /// <returns>The combined <see cref="ActionSpec"/>.</returns>
        internal static ActionSpec CombineActionSpecs(IList<IActuator> actuators)
        {
            int numContinuousActions = 0;
            int numDiscreteActions = 0;
            foreach (var actuator in actuators)
            {
                numContinuousActions += actuator.ActionSpec.NumContinuousActions;
                numDiscreteActions += actuator.ActionSpec.NumDiscreteActions;
            }

            int[] combinedBranchSizes;
            if (numDiscreteActions == 0)
            {
                combinedBranchSizes = Array.Empty<int>();
            }
            else
            {
                combinedBranchSizes = new int[numDiscreteActions];
                var start = 0;
                for (var i = 0; i < actuators.Count; i++)
                {
                    var branchSizes = actuators[i].ActionSpec.BranchSizes;
                    // Actuators without discrete branches may expose a null array; skip them.
                    if (branchSizes != null)
                    {
                        Array.Copy(branchSizes, 0, combinedBranchSizes, start, branchSizes.Length);
                        start += branchSizes.Length;
                    }
                }
            }

            return new ActionSpec(numContinuousActions, combinedBranchSizes);
        }

        /// <summary>
        /// Returns an ActionSpec representing the concatenation of all IActuator's ActionSpecs.
        /// Calling this readies the actuators for execution if that has not happened yet.
        /// </summary>
        /// <returns>The combined <see cref="ActionSpec"/>.</returns>
        public ActionSpec GetCombinedActionSpec()
        {
            ReadyActuatorsForExecution();
            return m_CombinedActionSpec;
        }

        /// <summary>
        /// Updates the local action buffer with the action buffer passed in. If a segment of the buffer
        /// passed in is empty, the corresponding local segment will be cleared instead.
        /// </summary>
        /// <param name="actions">The <see cref="ActionBuffers"/> object which contains all of the
        /// actions for the IActuators in this list.</param>
        public void UpdateActions(ActionBuffers actions)
        {
            ReadyActuatorsForExecution();
            UpdateActionArray(actions.ContinuousActions, StoredActions.ContinuousActions);
            UpdateActionArray(actions.DiscreteActions, StoredActions.DiscreteActions);
        }

        /// <summary>
        /// Copies <paramref name="sourceActionBuffer"/> into <paramref name="destination"/>, or clears the
        /// destination when the source is empty.
        /// </summary>
        /// <param name="sourceActionBuffer">The segment to copy from.</param>
        /// <param name="destination">The segment to copy into.</param>
        /// <typeparam name="T">The element type of the segments.</typeparam>
        static void UpdateActionArray<T>(ActionSegment<T> sourceActionBuffer, ActionSegment<T> destination)
            where T : struct
        {
            if (sourceActionBuffer.Length <= 0)
            {
                destination.Clear();
            }
            else
            {
                Debug.AssertFormat(sourceActionBuffer.Length == destination.Length,
                    "sourceActionBuffer: {0} is a different size than destination: {1}.",
                    sourceActionBuffer.Length,
                    destination.Length);

                Array.Copy(sourceActionBuffer.Array,
                    sourceActionBuffer.Offset,
                    destination.Array,
                    destination.Offset,
                    destination.Length);
            }
        }

        /// <summary>
        /// This method will trigger the writing to the <see cref="ActuatorDiscreteActionMask"/> by all of
        /// the actuators managed by this object.
        /// </summary>
        public void WriteActionMask()
        {
            ReadyActuatorsForExecution();
            m_DiscreteActionMask.ResetMask();
            var offset = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var actuator = m_Actuators[i];
                if (actuator.ActionSpec.NumDiscreteActions > 0)
                {
                    // Each actuator writes relative to the first branch that belongs to it.
                    m_DiscreteActionMask.CurrentBranchOffset = offset;
                    actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
                    offset += actuator.ActionSpec.NumDiscreteActions;
                }
            }
        }

        /// <summary>
        /// Iterates through all of the IActuators in this list and calls <c>Heuristic</c> on those that
        /// implement <see cref="IHeuristicProvider"/>, handing each one the slice of
        /// <paramref name="actionBuffersOut"/> that corresponds to its <see cref="ActionSpec"/>.
        /// </summary>
        /// <param name="actionBuffersOut">The <see cref="ActionBuffers"/> whose backing arrays the
        /// heuristics write into.</param>
        public void ApplyHeuristic(in ActionBuffers actionBuffersOut)
        {
            var continuousStart = 0;
            var discreteStart = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var actuator = m_Actuators[i];
                var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
                var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;

                if (numContinuousActions == 0 && numDiscreteActions == 0)
                {
                    continue;
                }

                var continuousActions = ActionSegment<float>.Empty;
                if (numContinuousActions > 0)
                {
                    continuousActions = new ActionSegment<float>(actionBuffersOut.ContinuousActions.Array,
                        continuousStart,
                        numContinuousActions);
                }

                var discreteActions = ActionSegment<int>.Empty;
                if (numDiscreteActions > 0)
                {
                    discreteActions = new ActionSegment<int>(actionBuffersOut.DiscreteActions.Array,
                        discreteStart,
                        numDiscreteActions);
                }

                var heuristic = actuator as IHeuristicProvider;
                heuristic?.Heuristic(new ActionBuffers(continuousActions, discreteActions));

                // Advance the offsets even when the actuator provides no heuristic so that the
                // following actuators still receive their own slice.
                continuousStart += numContinuousActions;
                discreteStart += numDiscreteActions;
            }
        }

        /// <summary>
        /// Iterates through all of the IActuators in this list and calls <c>OnActionReceived</c> on each of
        /// them with the slice of <see cref="StoredActions"/> that corresponds to its
        /// <see cref="ActionSpec"/>.
        /// </summary>
        public void ExecuteActions()
        {
            ReadyActuatorsForExecution();
            var continuousStart = 0;
            var discreteStart = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var actuator = m_Actuators[i];
                var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
                var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;

                var continuousActions = ActionSegment<float>.Empty;
                if (numContinuousActions > 0)
                {
                    continuousActions = new ActionSegment<float>(StoredActions.ContinuousActions.Array,
                        continuousStart,
                        numContinuousActions);
                }

                var discreteActions = ActionSegment<int>.Empty;
                if (numDiscreteActions > 0)
                {
                    discreteActions = new ActionSegment<int>(StoredActions.DiscreteActions.Array,
                        discreteStart,
                        numDiscreteActions);
                }

                actuator.OnActionReceived(new ActionBuffers(continuousActions, discreteActions));
                continuousStart += numContinuousActions;
                discreteStart += numDiscreteActions;
            }
        }

        /// <summary>
        /// Clears <see cref="StoredActions"/>, calls <c>ResetData</c> on each <see cref="IActuator"/>
        /// managed by this object and resets the discrete action mask.
        /// Does nothing if the buffers have not been allocated yet.
        /// </summary>
        public void ResetData()
        {
            if (!m_ReadyForExecution)
            {
                return;
            }
            StoredActions.Clear();
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                m_Actuators[i].ResetData();
            }
            m_DiscreteActionMask.ResetMask();
        }

        /// <summary>
        /// Sorts the <see cref="IActuator"/>s according to their <c>Name</c> value.
        /// </summary>
        void SortActuators()
        {
            // Compare with the invariant culture rather than the current one so that the resulting
            // order (and therefore the layout of the action buffers) is the same on every machine.
            m_Actuators.Sort((x, y) => string.Compare(x.Name, y.Name, StringComparison.InvariantCulture));
        }

        /// <summary>
        /// Validates that the IActuators managed by this object have unique names.
        /// Each Actuator needs to have a unique name in order for this object to ensure that the storage of action
        /// buffers, and execution of Actuators remains deterministic across different sessions of running.
        /// </summary>
        void ValidateActuators()
        {
            // A set catches duplicates wherever they sit in the list; comparing neighbours only
            // would require the list to be sorted already.
            var seenNames = new HashSet<string>();
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                // Keep the side effect outside of Debug.Assert, whose arguments are not evaluated
                // when assertions are compiled out.
                var isUnique = seenNames.Add(m_Actuators[i].Name);
                Debug.Assert(isUnique, "Actuator names must be unique.");
            }
        }

        /// <summary>
        /// Helper method to update bookkeeping items around buffer management for actuators added to this object.
        /// </summary>
        /// <param name="actuatorItem">The IActuator to keep bookkeeping for. Ignored when null.</param>
        void AddToBufferSizes(IActuator actuatorItem)
        {
            if (actuatorItem == null)
            {
                return;
            }

            NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions;
            NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions;
            SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
        }

        /// <summary>
        /// Helper method to update bookkeeping items around buffer management for actuators removed from this object.
        /// </summary>
        /// <param name="actuatorItem">The IActuator to keep bookkeeping for. Ignored when null.</param>
        void SubtractFromBufferSize(IActuator actuatorItem)
        {
            if (actuatorItem == null)
            {
                return;
            }

            NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions;
            NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions;
            SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
        }

        /// <summary>
        /// Sets all of the bookkeeping items back to 0.
        /// </summary>
        void ClearBufferSizes()
        {
            NumContinuousActions = NumDiscreteActions = SumOfDiscreteBranchSizes = 0;
        }

        /// <summary>
        /// Add an array of <see cref="IActuator"/>s at once.
        /// </summary>
        /// <param name="actuators">The array of <see cref="IActuator"/>s to add.</param>
        public void AddActuators(IActuator[] actuators)
        {
            for (var i = 0; i < actuators.Length; i++)
            {
                Add(actuators[i]);
            }
        }

        /*********************************************************************************
         * IList implementation that delegates to m_Actuators List.                      *
         *********************************************************************************/

        /// <inheritdoc/>
        public IEnumerator<IActuator> GetEnumerator()
        {
            return m_Actuators.GetEnumerator();
        }

        /// <inheritdoc/>
        IEnumerator IEnumerable.GetEnumerator()
        {
            return ((IEnumerable)m_Actuators).GetEnumerator();
        }

        /// <inheritdoc/>
        public void Add(IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot add to the ActuatorManager after its buffers have been initialized");
            m_Actuators.Add(item);
            AddToBufferSizes(item);
        }

        /// <inheritdoc/>
        public void Clear()
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot clear the ActuatorManager after its buffers have been initialized");
            m_Actuators.Clear();
            ClearBufferSizes();
        }

        /// <inheritdoc/>
        public bool Contains(IActuator item)
        {
            return m_Actuators.Contains(item);
        }

        /// <inheritdoc/>
        public void CopyTo(IActuator[] array, int arrayIndex)
        {
            m_Actuators.CopyTo(array, arrayIndex);
        }

        /// <inheritdoc/>
        public bool Remove(IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot remove from the ActuatorManager after its buffers have been initialized");
            if (m_Actuators.Remove(item))
            {
                SubtractFromBufferSize(item);
                return true;
            }
            return false;
        }

        /// <inheritdoc/>
        public int Count => m_Actuators.Count;

        /// <inheritdoc/>
        public bool IsReadOnly => false;

        /// <inheritdoc/>
        public int IndexOf(IActuator item)
        {
            return m_Actuators.IndexOf(item);
        }

        /// <inheritdoc/>
        public void Insert(int index, IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot insert into the ActuatorManager after its buffers have been initialized");
            m_Actuators.Insert(index, item);
            AddToBufferSizes(item);
        }

        /// <inheritdoc/>
        public void RemoveAt(int index)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot remove from the ActuatorManager after its buffers have been initialized");
            var actuator = m_Actuators[index];
            SubtractFromBufferSize(actuator);
            m_Actuators.RemoveAt(index);
        }

        /// <inheritdoc/>
        public IActuator this[int index]
        {
            get => m_Actuators[index];
            set
            {
                Debug.Assert(m_ReadyForExecution == false,
                    "Cannot modify the ActuatorManager after its buffers have been initialized");
                // Swap the bookkeeping of the replaced actuator for that of the new one.
                var old = m_Actuators[index];
                SubtractFromBufferSize(old);
                m_Actuators[index] = value;
                AddToBufferSizes(value);
            }
        }
    }
}