using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Manages the delegation of events, the action buffers, and the discrete action mask for a list of
/// <see cref="IActuator"/>s. The class is itself an <see cref="IList{T}"/> of actuators; every mutating
/// list operation also updates the running action-count totals used to size the buffers.
/// </summary>
internal class ActuatorManager : IList<IActuator>
{
    // IActuators managed by this object.
    readonly IList<IActuator> m_Actuators;

    // Discrete action mask the actuators write into; each actuator writes at its own branch offset.
    ActuatorDiscreteActionMask m_DiscreteActionMask;

    /// <summary>
    /// Flag used to check if our IActuators are ready for execution, i.e. the buffers and the
    /// action mask have been allocated by <c>ReadyActuatorsForExecution</c>.
    /// </summary>
    bool m_ReadyForExecution;

    /// <summary>
    /// The sum of all of the discrete branch sizes for all of the <see cref="IActuator"/>s in this manager.
    /// </summary>
    internal int SumOfDiscreteBranchSizes { get; private set; }

    /// <summary>
    /// The number of discrete branches for all of the <see cref="IActuator"/>s in this manager.
    /// </summary>
    internal int NumDiscreteActions { get; private set; }

    /// <summary>
    /// The number of continuous actions for all of the <see cref="IActuator"/>s in this manager.
    /// </summary>
    internal int NumContinuousActions { get; private set; }

    /// <summary>
    /// Returns the total number of actions, which is calculated as
    /// <see cref="NumContinuousActions"/> + <see cref="NumDiscreteActions"/>.
    /// </summary>
    public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions;

    /// <summary>
    /// Gets the <see cref="ActuatorDiscreteActionMask"/> managed by this object.
    /// It is null until the manager has been readied for execution.
    /// </summary>
    public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask;

    /// <summary>
    /// Returns the previously stored continuous actions for the actuators in this list.
    /// </summary>
    public float[] StoredContinuousActions { get; private set; }

    /// <summary>
    /// Returns the previously stored discrete actions for the actuators in this list.
    /// </summary>
    public int[] StoredDiscreteActions { get; private set; }

    /// <summary>
    /// Creates an ActuatorManager whose backing list has a preset capacity.
    /// </summary>
    /// <param name="capacity">The capacity of the list to create.</param>
    public ActuatorManager(int capacity = 0)
    {
        m_Actuators = new List<IActuator>(capacity);
    }

    /// <summary>
    /// Readies the managed actuators for execution using the totals tracked by this object.
    /// </summary>
    void ReadyActuatorsForExecution()
    {
        ReadyActuatorsForExecution(m_Actuators, NumContinuousActions, SumOfDiscreteBranchSizes,
            NumDiscreteActions);
    }

    /// <summary>
    /// Validates the managed <see cref="IActuator"/>s (only if the `DEBUG` preprocessor macro is defined),
    /// sorts them by name, and allocates the buffers and the action mask used to manage their actions.
    /// Does nothing if the manager has already been readied.
    /// </summary>
    /// <remarks>
    /// Validation and sorting always operate on the actuators held by this manager;
    /// <paramref name="actuators"/> is only forwarded to the action mask.
    /// </remarks>
    /// <param name="actuators">The list of actuators handed to the discrete action mask.</param>
    /// <param name="numContinuousActions">The total number of continuous actions for all of the actuators.</param>
    /// <param name="sumOfDiscreteBranches">The total sum of the discrete branch sizes for all of the actuators,
    /// needed to allocate the <see cref="ActuatorDiscreteActionMask"/>.</param>
    /// <param name="numDiscreteBranches">The number of discrete branches for all of the actuators.</param>
    internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numContinuousActions, int sumOfDiscreteBranches, int numDiscreteBranches)
    {
        if (m_ReadyForExecution)
        {
            return;
        }
#if DEBUG
        // Make sure the names are actually unique.
        // Make sure all Actuators have the same SpaceType.
        ValidateActuators();
#endif

        // Sort the Actuators by name to ensure determinism.
        SortActuators();
        StoredContinuousActions = numContinuousActions == 0 ? Array.Empty<float>() : new float[numContinuousActions];
        StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty<int>() : new int[numDiscreteBranches];
        m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches);
        m_ReadyForExecution = true;
    }

    /// <summary>
    /// Updates the local action buffers with the action buffers passed in. If a buffer
    /// passed in is null or empty, the corresponding local action buffer is cleared.
    /// </summary>
    /// <param name="continuousActionBuffer">The action buffer which contains all of the
    /// continuous actions for the IActuators in this list.</param>
    /// <param name="discreteActionBuffer">The action buffer which contains all of the
    /// discrete actions for the IActuators in this list.</param>
    public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer)
    {
        ReadyActuatorsForExecution();
        UpdateActionArray(continuousActionBuffer, StoredContinuousActions);
        UpdateActionArray(discreteActionBuffer, StoredDiscreteActions);
    }

    /// <summary>
    /// Copies <paramref name="sourceActionBuffer"/> into <paramref name="destination"/>, or zeroes
    /// <paramref name="destination"/> when the source is null or empty.
    /// </summary>
    /// <typeparam name="T">Element type of the action buffers.</typeparam>
    /// <param name="sourceActionBuffer">The incoming actions; expected to be the same length as the destination.</param>
    /// <param name="destination">The stored action buffer to overwrite.</param>
    static void UpdateActionArray<T>(T[] sourceActionBuffer, T[] destination)
    {
        if (sourceActionBuffer == null || sourceActionBuffer.Length == 0)
        {
            Array.Clear(destination, 0, destination.Length);
        }
        else
        {
            Debug.Assert(sourceActionBuffer.Length == destination.Length,
                $"sourceActionBuffer:{sourceActionBuffer.Length} is a different" +
                $" size than destination: {destination.Length}.");

            Array.Copy(sourceActionBuffer, destination, destination.Length);
        }
    }

    /// <summary>
    /// Triggers the writing to the <see cref="ActuatorDiscreteActionMask"/> by all of the actuators
    /// managed by this object. The mask is reset first, and each actuator writes at an offset equal to
    /// the number of discrete branches of the actuators before it.
    /// </summary>
    public void WriteActionMask()
    {
        ReadyActuatorsForExecution();
        m_DiscreteActionMask.ResetMask();
        var offset = 0;
        for (var i = 0; i < m_Actuators.Count; i++)
        {
            var actuator = m_Actuators[i];
            m_DiscreteActionMask.CurrentBranchOffset = offset;
            actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
            offset += actuator.ActionSpec.NumDiscreteActions;
        }
    }

    /// <summary>
    /// Iterates through all of the IActuators in this list and calls <c>OnActionReceived</c> on each of
    /// them with the segments of the stored continuous and discrete actions that belong to that
    /// actuator, as determined by its <c>ActionSpec</c>.
    /// </summary>
    public void ExecuteActions()
    {
        ReadyActuatorsForExecution();
        var continuousStart = 0;
        var discreteStart = 0;
        for (var i = 0; i < m_Actuators.Count; i++)
        {
            var actuator = m_Actuators[i];
            var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
            var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;

            // An actuator with no actions of a given kind receives the empty segment for that kind.
            var continuousActions = ActionSegment<float>.Empty;
            if (numContinuousActions > 0)
            {
                continuousActions = new ActionSegment<float>(StoredContinuousActions,
                    continuousStart,
                    numContinuousActions);
            }

            var discreteActions = ActionSegment<int>.Empty;
            if (numDiscreteActions > 0)
            {
                discreteActions = new ActionSegment<int>(StoredDiscreteActions,
                    discreteStart,
                    numDiscreteActions);
            }

            actuator.OnActionReceived(new ActionBuffers(continuousActions, discreteActions));
            continuousStart += numContinuousActions;
            discreteStart += numDiscreteActions;
        }
    }

    /// <summary>
    /// Resets the <see cref="StoredContinuousActions"/> and <see cref="StoredDiscreteActions"/> buffers to
    /// be all zeros and calls <c>ResetData</c> on each <see cref="IActuator"/> managed by this object.
    /// Does nothing if the manager has not been readied for execution yet.
    /// </summary>
    public void ResetData()
    {
        if (!m_ReadyForExecution)
        {
            return;
        }
        Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length);
        Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length);
        for (var i = 0; i < m_Actuators.Count; i++)
        {
            m_Actuators[i].ResetData();
        }
    }

    /// <summary>
    /// Sorts the <see cref="IActuator"/>s according to their <c>Name</c> value.
    /// </summary>
    void SortActuators()
    {
        // Compare with the invariant culture rather than the current one: the resulting order decides
        // how the action buffers are laid out, so it must not depend on the machine's locale.
        ((List<IActuator>)m_Actuators).Sort(
            (x, y) => string.Compare(x.Name, y.Name, StringComparison.InvariantCulture));
    }

    /// <summary>
    /// Validates that the IActuators managed by this object have unique names and equivalent action space types.
    /// Each Actuator needs to have a unique name in order for this object to ensure that the storage of action
    /// buffers, and execution of Actuators remains deterministic across different sessions of running.
    /// </summary>
    void ValidateActuators()
    {
        // Track every name seen so far so duplicates are caught no matter where they sit in the list;
        // this check does not rely on the list having been sorted.
        var seenNames = new HashSet<string>();
        for (var i = 0; i < m_Actuators.Count; i++)
        {
            // Evaluate Add outside of the assert: Unity's Debug.Assert is a conditional method, so its
            // arguments may not be evaluated at all.
            var isUnique = seenNames.Add(m_Actuators[i].Name);
            Debug.Assert(isUnique, "Actuator names must be unique.");

            if (i == 0)
            {
                continue;
            }

            // Checking each neighbouring pair is enough to establish that all actuators agree.
            var first = m_Actuators[i - 1].ActionSpec;
            var second = m_Actuators[i].ActionSpec;
            Debug.Assert((first.NumContinuousActions > 0) == (second.NumContinuousActions > 0),
                "Actuators on the same Agent must have the same action SpaceType.");
        }
    }

    /// <summary>
    /// Helper method to update bookkeeping items around buffer management for actuators added to this object.
    /// A null actuator contributes nothing.
    /// </summary>
    /// <param name="actuatorItem">The IActuator to keep bookkeeping for.</param>
    void AddToBufferSizes(IActuator actuatorItem)
    {
        if (actuatorItem == null)
        {
            return;
        }

        NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions;
        NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions;
        SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
    }

    /// <summary>
    /// Helper method to update bookkeeping items around buffer management for actuators removed from this object.
    /// A null actuator contributes nothing.
    /// </summary>
    /// <param name="actuatorItem">The IActuator to keep bookkeeping for.</param>
    void SubtractFromBufferSize(IActuator actuatorItem)
    {
        if (actuatorItem == null)
        {
            return;
        }

        NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions;
        NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions;
        SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
    }

    /// <summary>
    /// Sets all of the bookkeeping items back to 0.
    /// </summary>
    void ClearBufferSizes()
    {
        NumContinuousActions = NumDiscreteActions = SumOfDiscreteBranchSizes = 0;
    }

    /*********************************************************************************
     * IList implementation that delegates to m_Actuators List.                      *
     *********************************************************************************/

    /// <summary>
    /// Returns an enumerator over the managed actuators.
    /// </summary>
    /// <returns>The enumerator of the underlying list.</returns>
    public IEnumerator<IActuator> GetEnumerator()
    {
        return m_Actuators.GetEnumerator();
    }

    /// <summary>
    /// Returns a non-generic enumerator over the managed actuators.
    /// </summary>
    /// <returns>The non-generic enumerator of the underlying list.</returns>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return ((IEnumerable)m_Actuators).GetEnumerator();
    }

    /// <summary>
    /// Adds an actuator to the end of the list and adds its action counts to the running totals.
    /// Asserts that the buffers have not been initialized yet.
    /// </summary>
    /// <param name="item">The actuator to add.</param>
    public void Add(IActuator item)
    {
        Debug.Assert(!m_ReadyForExecution,
            "Cannot add to the ActuatorManager after its buffers have been initialized");
        m_Actuators.Add(item);
        AddToBufferSizes(item);
    }

    /// <summary>
    /// Removes all actuators and resets the running totals to 0.
    /// Asserts that the buffers have not been initialized yet.
    /// </summary>
    public void Clear()
    {
        Debug.Assert(!m_ReadyForExecution,
            "Cannot clear the ActuatorManager after its buffers have been initialized");
        m_Actuators.Clear();
        ClearBufferSizes();
    }

    /// <summary>
    /// Determines whether the given actuator is managed by this object.
    /// </summary>
    /// <param name="item">The actuator to look for.</param>
    /// <returns>True if the underlying list contains <paramref name="item"/>.</returns>
    public bool Contains(IActuator item)
    {
        return m_Actuators.Contains(item);
    }

    /// <summary>
    /// Copies the managed actuators into <paramref name="array"/>, starting at <paramref name="arrayIndex"/>.
    /// </summary>
    /// <param name="array">The destination array.</param>
    /// <param name="arrayIndex">The index in <paramref name="array"/> at which copying begins.</param>
    public void CopyTo(IActuator[] array, int arrayIndex)
    {
        m_Actuators.CopyTo(array, arrayIndex);
    }

    /// <summary>
    /// Removes the given actuator and, if it was present, subtracts its action counts from the running totals.
    /// Asserts that the buffers have not been initialized yet.
    /// </summary>
    /// <param name="item">The actuator to remove.</param>
    /// <returns>True if the actuator was found and removed; false otherwise.</returns>
    public bool Remove(IActuator item)
    {
        Debug.Assert(!m_ReadyForExecution,
            "Cannot remove from the ActuatorManager after its buffers have been initialized");
        if (m_Actuators.Remove(item))
        {
            SubtractFromBufferSize(item);
            return true;
        }
        return false;
    }

    /// <summary>
    /// The number of actuators managed by this object.
    /// </summary>
    public int Count => m_Actuators.Count;

    /// <summary>
    /// Whether the underlying list is read-only.
    /// </summary>
    public bool IsReadOnly => m_Actuators.IsReadOnly;

    /// <summary>
    /// Returns the index of the given actuator in the underlying list.
    /// </summary>
    /// <param name="item">The actuator to look for.</param>
    /// <returns>The index reported by the underlying list.</returns>
    public int IndexOf(IActuator item)
    {
        return m_Actuators.IndexOf(item);
    }

    /// <summary>
    /// Inserts an actuator at the given index and adds its action counts to the running totals.
    /// Asserts that the buffers have not been initialized yet.
    /// </summary>
    /// <param name="index">The index at which to insert.</param>
    /// <param name="item">The actuator to insert.</param>
    public void Insert(int index, IActuator item)
    {
        Debug.Assert(!m_ReadyForExecution,
            "Cannot insert into the ActuatorManager after its buffers have been initialized");
        m_Actuators.Insert(index, item);
        AddToBufferSizes(item);
    }

    /// <summary>
    /// Removes the actuator at the given index and subtracts its action counts from the running totals.
    /// Asserts that the buffers have not been initialized yet.
    /// </summary>
    /// <param name="index">The index of the actuator to remove.</param>
    public void RemoveAt(int index)
    {
        Debug.Assert(!m_ReadyForExecution,
            "Cannot remove from the ActuatorManager after its buffers have been initialized");
        // Read the actuator first: an invalid index throws here, before any total is modified.
        var actuator = m_Actuators[index];
        SubtractFromBufferSize(actuator);
        m_Actuators.RemoveAt(index);
    }

    /// <summary>
    /// Gets or sets the actuator at the given index. Setting replaces the actuator and moves the running
    /// totals from the old actuator's action counts to the new one's; the setter asserts that the buffers
    /// have not been initialized yet.
    /// </summary>
    /// <param name="index">The index of the actuator.</param>
    public IActuator this[int index]
    {
        get => m_Actuators[index];
        set
        {
            Debug.Assert(!m_ReadyForExecution,
                "Cannot modify the ActuatorManager after its buffers have been initialized");
            var old = m_Actuators[index];
            SubtractFromBufferSize(old);
            m_Actuators[index] = value;
            AddToBufferSizes(value);
        }
    }
}
}