using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// A class that manages the delegation of events, action buffers, and action mask for a list of
/// <see cref="IActuator"/>s.
/// </summary>
/// <remarks>
/// The list is expected to be modified (add, insert, remove, replace, clear) only before the manager is
/// readied for execution; later modifications trigger a <c>Debug.Assert</c>. Readying sorts the actuators
/// by name and allocates the combined action buffers and the discrete action mask.
/// </remarks>
internal class ActuatorManager : IList<IActuator>
{
    // IActuators managed by this object. The constructor always creates a List<IActuator>,
    // which SortActuators relies on when it casts this field.
    readonly IList<IActuator> m_Actuators;

    // An implementation of IDiscreteActionMask that allows for writing to it based on an offset.
    ActuatorDiscreteActionMask m_DiscreteActionMask;

    // Concatenation of every managed actuator's ActionSpec; computed when readied for execution.
    ActionSpec m_CombinedActionSpec;

    /// <summary>
    /// Flag used to check if our IActuators are ready for execution (sorted, validated and with
    /// their buffers allocated).
    /// </summary>
    bool m_ReadyForExecution;

    /// <summary>
    /// The sum of all of the discrete branch sizes for all of the <see cref="IActuator"/>s in this manager.
    /// </summary>
    internal int SumOfDiscreteBranchSizes { get; private set; }

    /// <summary>
    /// The number of discrete branches for all of the <see cref="IActuator"/>s in this manager.
    /// </summary>
    internal int NumDiscreteActions { get; private set; }

    /// <summary>
    /// The number of continuous actions for all of the <see cref="IActuator"/>s in this manager.
    /// </summary>
    internal int NumContinuousActions { get; private set; }

    /// <summary>
    /// Returns the total number of actions, calculated as
    /// <see cref="NumContinuousActions"/> + <see cref="NumDiscreteActions"/>.
    /// </summary>
    public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions;

    /// <summary>
    /// Gets the <see cref="ActuatorDiscreteActionMask"/> managed by this object. It is only allocated
    /// once the manager has been readied for execution.
    /// </summary>
    public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask;

    /// <summary>
    /// The currently stored <see cref="ActionBuffers"/> for the <see cref="IActuator"/>s managed by this
    /// class. Allocated when the manager is readied for execution and filled by <see cref="UpdateActions"/>.
    /// </summary>
    public ActionBuffers StoredActions { get; private set; }

    /// <summary>
    /// Create an empty ActuatorManager whose backing list has a preset capacity.
    /// </summary>
    /// <param name="capacity">The initial capacity of the backing list.</param>
    public ActuatorManager(int capacity = 0)
    {
        m_Actuators = new List<IActuator>(capacity);
    }

    /// <summary>
    /// Readies this manager's own actuator list using the running totals kept by the
    /// add/remove bookkeeping.
    /// </summary>
    void ReadyActuatorsForExecution()
    {
        ReadyActuatorsForExecution(m_Actuators, NumContinuousActions, SumOfDiscreteBranchSizes,
            NumDiscreteActions);
    }

    /// <summary>
    /// Sorts the managed <see cref="IActuator"/>s by name and validates that their names are unique (the
    /// validation only runs if the <c>DEBUG</c> preprocessor symbol is defined). It then allocates the
    /// buffers needed to manage the actions of all of the actuators. Does nothing if the manager is
    /// already ready.
    /// </summary>
    /// <param name="actuators">The list of actuators to combine action specs and build the discrete
    /// action mask for. The parameterless overload passes this manager's own list.</param>
    /// <param name="numContinuousActions">The total number of continuous actions for all of the actuators.</param>
    /// <param name="sumOfDiscreteBranches">The total sum of the discrete branches for all of the actuators in
    /// order to be able to allocate an <see cref="IDiscreteActionMask"/>.</param>
    /// <param name="numDiscreteBranches">The number of discrete branches for all of the actuators.</param>
    internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numContinuousActions, int sumOfDiscreteBranches, int numDiscreteBranches)
    {
        if (m_ReadyForExecution)
        {
            return;
        }

        // Sort the actuators by name to ensure a deterministic buffer layout. Sorting has to happen
        // before validation: ValidateActuators only compares neighbouring entries, so it detects every
        // duplicate name only once equal names are adjacent.
        SortActuators();
#if DEBUG
        // Make sure the names are actually unique.
        ValidateActuators();
#endif
        var continuousActions = numContinuousActions == 0
            ? ActionSegment<float>.Empty
            : new ActionSegment<float>(new float[numContinuousActions]);
        var discreteActions = numDiscreteBranches == 0
            ? ActionSegment<int>.Empty
            : new ActionSegment<int>(new int[numDiscreteBranches]);
        StoredActions = new ActionBuffers(continuousActions, discreteActions);
        m_CombinedActionSpec = CombineActionSpecs(actuators);
        m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches, m_CombinedActionSpec.BranchSizes);
        m_ReadyForExecution = true;
    }

    /// <summary>
    /// Builds a single <see cref="ActionSpec"/> by concatenating the action specs of the given actuators.
    /// Continuous action counts are summed, and the discrete branch sizes are appended in list order.
    /// </summary>
    /// <param name="actuators">The actuators whose specs are combined.</param>
    /// <returns>The combined action spec.</returns>
    internal static ActionSpec CombineActionSpecs(IList<IActuator> actuators)
    {
        int numContinuousActions = 0;
        int numDiscreteActions = 0;
        foreach (var actuator in actuators)
        {
            numContinuousActions += actuator.ActionSpec.NumContinuousActions;
            numDiscreteActions += actuator.ActionSpec.NumDiscreteActions;
        }

        int[] combinedBranchSizes;
        if (numDiscreteActions == 0)
        {
            combinedBranchSizes = Array.Empty<int>();
        }
        else
        {
            combinedBranchSizes = new int[numDiscreteActions];
            var start = 0;
            for (var i = 0; i < actuators.Count; i++)
            {
                var branchSizes = actuators[i].ActionSpec.BranchSizes;
                if (branchSizes != null)
                {
                    Array.Copy(branchSizes, 0, combinedBranchSizes, start, branchSizes.Length);
                    start += branchSizes.Length;
                }
            }
        }

        return new ActionSpec(numContinuousActions, combinedBranchSizes);
    }

    /// <summary>
    /// Returns an ActionSpec representing the concatenation of all IActuator's ActionSpecs.
    /// Readies the manager for execution first if needed.
    /// </summary>
    /// <returns>The combined <see cref="ActionSpec"/>.</returns>
    public ActionSpec GetCombinedActionSpec()
    {
        ReadyActuatorsForExecution();
        return m_CombinedActionSpec;
    }

    /// <summary>
    /// Updates the local action buffer with the action buffer passed in. If a segment (continuous or
    /// discrete) of the buffer passed in is empty, the corresponding local segment is cleared instead.
    /// </summary>
    /// <param name="actions">The <see cref="ActionBuffers"/> object which contains all of the
    /// actions for the IActuators in this list.</param>
    public void UpdateActions(ActionBuffers actions)
    {
        ReadyActuatorsForExecution();
        UpdateActionArray(actions.ContinuousActions, StoredActions.ContinuousActions);
        UpdateActionArray(actions.DiscreteActions, StoredActions.DiscreteActions);
    }

    /// <summary>
    /// Copies <paramref name="sourceActionBuffer"/> into <paramref name="destination"/>, or clears
    /// <paramref name="destination"/> when the source is empty.
    /// </summary>
    /// <remarks>
    /// The segments are expected to have the same length. A mismatch is only reported through
    /// <c>Debug.AssertFormat</c>; the copy always uses <c>destination.Length</c> elements.
    /// </remarks>
    static void UpdateActionArray<T>(ActionSegment<T> sourceActionBuffer, ActionSegment<T> destination)
        where T : struct
    {
        if (sourceActionBuffer.Length <= 0)
        {
            destination.Clear();
        }
        else
        {
            Debug.AssertFormat(sourceActionBuffer.Length == destination.Length,
                "sourceActionBuffer: {0} is a different size than destination: {1}.",
                sourceActionBuffer.Length,
                destination.Length);

            Array.Copy(sourceActionBuffer.Array,
                sourceActionBuffer.Offset,
                destination.Array,
                destination.Offset,
                destination.Length);
        }
    }

    /// <summary>
    /// This method will trigger the writing to the <see cref="IDiscreteActionMask"/> by all of the actuators
    /// managed by this object. The mask is reset first. Then each actuator with discrete actions writes to
    /// the mask with <c>CurrentBranchOffset</c> set to the index of that actuator's first branch.
    /// </summary>
    public void WriteActionMask()
    {
        ReadyActuatorsForExecution();
        m_DiscreteActionMask.ResetMask();
        var offset = 0;
        for (var i = 0; i < m_Actuators.Count; i++)
        {
            var actuator = m_Actuators[i];
            if (actuator.ActionSpec.NumDiscreteActions > 0)
            {
                m_DiscreteActionMask.CurrentBranchOffset = offset;
                actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
                offset += actuator.ActionSpec.NumDiscreteActions;
            }
        }
    }

    /// <summary>
    /// Iterates through all of the IActuators in this list and calls their
    /// <see cref="IHeuristicProvider.Heuristic"/> method on them, if implemented, with the appropriate
    /// <see cref="ActionSegment{T}"/>s of <paramref name="actionBuffersOut"/> depending on their
    /// <see cref="ActionSpec"/>.
    /// </summary>
    /// <param name="actionBuffersOut">The buffers the heuristics write into, laid out in the same
    /// (sorted) actuator order that <see cref="ExecuteActions"/> reads from.</param>
    public void ApplyHeuristic(in ActionBuffers actionBuffersOut)
    {
        // Ensure the actuators are sorted so the heuristic layout matches the layout used by
        // ExecuteActions and the combined action spec.
        ReadyActuatorsForExecution();
        var continuousStart = 0;
        var discreteStart = 0;
        for (var i = 0; i < m_Actuators.Count; i++)
        {
            var actuator = m_Actuators[i];
            var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
            var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;

            if (numContinuousActions == 0 && numDiscreteActions == 0)
            {
                continue;
            }

            var continuousActions = ActionSegment<float>.Empty;
            if (numContinuousActions > 0)
            {
                continuousActions = new ActionSegment<float>(actionBuffersOut.ContinuousActions.Array,
                    continuousStart,
                    numContinuousActions);
            }

            var discreteActions = ActionSegment<int>.Empty;
            if (numDiscreteActions > 0)
            {
                discreteActions = new ActionSegment<int>(actionBuffersOut.DiscreteActions.Array,
                    discreteStart,
                    numDiscreteActions);
            }

            if (actuator is IHeuristicProvider heuristic)
            {
                heuristic.Heuristic(new ActionBuffers(continuousActions, discreteActions));
            }

            // Offsets advance even for actuators without a heuristic so later actuators keep their slots.
            continuousStart += numContinuousActions;
            discreteStart += numDiscreteActions;
        }
    }

    /// <summary>
    /// Iterates through all of the IActuators in this list and calls their
    /// <see cref="IActionReceiver.OnActionReceived"/> method on them with the appropriate
    /// <see cref="ActionSegment{T}"/>s of <see cref="StoredActions"/> depending on their
    /// <see cref="ActionSpec"/>.
    /// </summary>
    public void ExecuteActions()
    {
        ReadyActuatorsForExecution();
        var continuousStart = 0;
        var discreteStart = 0;
        for (var i = 0; i < m_Actuators.Count; i++)
        {
            var actuator = m_Actuators[i];
            var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
            var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;

            var continuousActions = ActionSegment<float>.Empty;
            if (numContinuousActions > 0)
            {
                continuousActions = new ActionSegment<float>(StoredActions.ContinuousActions.Array,
                    continuousStart,
                    numContinuousActions);
            }

            var discreteActions = ActionSegment<int>.Empty;
            if (numDiscreteActions > 0)
            {
                discreteActions = new ActionSegment<int>(StoredActions.DiscreteActions.Array,
                    discreteStart,
                    numDiscreteActions);
            }

            actuator.OnActionReceived(new ActionBuffers(continuousActions, discreteActions));
            continuousStart += numContinuousActions;
            discreteStart += numDiscreteActions;
        }
    }

    /// <summary>
    /// Clears <see cref="StoredActions"/>, calls <see cref="IActuator.ResetData"/> on each
    /// <see cref="IActuator"/> managed by this object, and resets the discrete action mask.
    /// Does nothing if the manager has not been readied for execution yet.
    /// </summary>
    public void ResetData()
    {
        if (!m_ReadyForExecution)
        {
            return;
        }

        StoredActions.Clear();
        for (var i = 0; i < m_Actuators.Count; i++)
        {
            m_Actuators[i].ResetData();
        }

        m_DiscreteActionMask.ResetMask();
    }

    /// <summary>
    /// Sorts the <see cref="IActuator"/>s according to their <see cref="IActuator.Name"/> value.
    /// </summary>
    void SortActuators()
    {
        // Compare with the invariant culture so the ordering (and therefore the action buffer layout)
        // does not depend on the current culture of the machine running the game.
        ((List<IActuator>)m_Actuators).Sort(
            (x, y) => string.Compare(x.Name, y.Name, StringComparison.InvariantCulture));
    }

    /// <summary>
    /// Validates that the IActuators managed by this object have unique names.
    /// Each Actuator needs to have a unique name in order for this object to ensure that the storage of action
    /// buffers, and execution of Actuators remains deterministic across different sessions of running.
    /// </summary>
    /// <remarks>
    /// Only neighbouring names are compared, so the list must already be sorted by name.
    /// </remarks>
    void ValidateActuators()
    {
        for (var i = 0; i < m_Actuators.Count - 1; i++)
        {
            Debug.Assert(
                !m_Actuators[i].Name.Equals(m_Actuators[i + 1].Name),
                "Actuator names must be unique.");
        }
    }

    /// <summary>
    /// Helper method to update bookkeeping items around buffer management for actuators added to this object.
    /// Null actuators are ignored.
    /// </summary>
    /// <param name="actuatorItem">The IActuator to keep bookkeeping for.</param>
    void AddToBufferSizes(IActuator actuatorItem)
    {
        if (actuatorItem == null)
        {
            return;
        }

        NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions;
        NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions;
        SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
    }

    /// <summary>
    /// Helper method to update bookkeeping items around buffer management for actuators removed from this object.
    /// Null actuators are ignored.
    /// </summary>
    /// <param name="actuatorItem">The IActuator to keep bookkeeping for.</param>
    void SubtractFromBufferSize(IActuator actuatorItem)
    {
        if (actuatorItem == null)
        {
            return;
        }

        NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions;
        NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions;
        SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
    }

    /// <summary>
    /// Sets all of the bookkeeping items back to 0.
    /// </summary>
    void ClearBufferSizes()
    {
        NumContinuousActions = NumDiscreteActions = SumOfDiscreteBranchSizes = 0;
    }

    /*********************************************************************************
     * IList implementation that delegates to m_Actuators List.                      *
     *********************************************************************************/

    /// <inheritdoc/>
    public IEnumerator<IActuator> GetEnumerator()
    {
        return m_Actuators.GetEnumerator();
    }

    /// <inheritdoc/>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return ((IEnumerable)m_Actuators).GetEnumerator();
    }

    /// <inheritdoc/>
    public void Add(IActuator item)
    {
        Debug.Assert(m_ReadyForExecution == false,
            "Cannot add to the ActuatorManager after its buffers have been initialized");
        m_Actuators.Add(item);
        AddToBufferSizes(item);
    }

    /// <inheritdoc/>
    public void Clear()
    {
        Debug.Assert(m_ReadyForExecution == false,
            "Cannot clear the ActuatorManager after its buffers have been initialized");
        m_Actuators.Clear();
        ClearBufferSizes();
    }

    /// <inheritdoc/>
    public bool Contains(IActuator item)
    {
        return m_Actuators.Contains(item);
    }

    /// <inheritdoc/>
    public void CopyTo(IActuator[] array, int arrayIndex)
    {
        m_Actuators.CopyTo(array, arrayIndex);
    }

    /// <inheritdoc/>
    public bool Remove(IActuator item)
    {
        Debug.Assert(m_ReadyForExecution == false,
            "Cannot remove from the ActuatorManager after its buffers have been initialized");
        if (m_Actuators.Remove(item))
        {
            SubtractFromBufferSize(item);
            return true;
        }

        return false;
    }

    /// <inheritdoc/>
    public int Count => m_Actuators.Count;

    /// <inheritdoc/>
    public bool IsReadOnly => m_Actuators.IsReadOnly;

    /// <inheritdoc/>
    public int IndexOf(IActuator item)
    {
        return m_Actuators.IndexOf(item);
    }

    /// <inheritdoc/>
    public void Insert(int index, IActuator item)
    {
        Debug.Assert(m_ReadyForExecution == false,
            "Cannot insert into the ActuatorManager after its buffers have been initialized");
        m_Actuators.Insert(index, item);
        AddToBufferSizes(item);
    }

    /// <inheritdoc/>
    public void RemoveAt(int index)
    {
        Debug.Assert(m_ReadyForExecution == false,
            "Cannot remove from the ActuatorManager after its buffers have been initialized");
        var actuator = m_Actuators[index];
        SubtractFromBufferSize(actuator);
        m_Actuators.RemoveAt(index);
    }

    /// <inheritdoc/>
    public IActuator this[int index]
    {
        get => m_Actuators[index];
        set
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot modify the ActuatorManager after its buffers have been initialized");
            var old = m_Actuators[index];
            SubtractFromBufferSize(old);
            m_Actuators[index] = value;
            AddToBufferSizes(value);
        }
    }
}
}