浏览代码

Add internal API for Actuators along with their tests. (#4297)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
323d911b
共有 33 个文件被更改,包括 1744 次插入1 次删除
  1. 2
      Project/ProjectSettings/ProjectVersion.txt
  2. 8
      com.unity.ml-agents/Runtime/Actuators.meta
  3. 8
      com.unity.ml-agents/Tests/Editor/Actuators.meta
  4. 181
      com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
  5. 3
      com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta
  6. 75
      com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
  7. 3
      com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta
  8. 17
      com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs
  9. 3
      com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta
  10. 150
      com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs
  11. 3
      com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta
  12. 415
      com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
  13. 3
      com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta
  14. 101
      com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
  15. 3
      com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta
  16. 21
      com.unity.ml-agents/Runtime/Actuators/IActuator.cs
  17. 3
      com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta
  18. 38
      com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
  19. 3
      com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta
  20. 72
      com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
  21. 3
      com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta
  22. 55
      com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs
  23. 3
      com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs.meta
  24. 114
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs
  25. 3
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs.meta
  26. 310
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
  27. 3
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs.meta
  28. 38
      com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs
  29. 3
      com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs.meta
  30. 98
      com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs
  31. 3
      com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs.meta

2
Project/ProjectSettings/ProjectVersion.txt


m_EditorVersion: 2018.4.17f1
m_EditorVersion: 2018.4.24f1

8
com.unity.ml-agents/Runtime/Actuators.meta


fileFormatVersion: 2
guid: 26733e59183b6479e8f0e892a8bf09a4
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

8
com.unity.ml-agents/Tests/Editor/Actuators.meta


fileFormatVersion: 2
guid: c7e705f7d549e43c6be18ae809cd6f54
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

181
com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs


using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// ActionSegment{T} is a data structure that allows access to a segment of an underlying array
/// in order to avoid the copying and allocation of sub-arrays. The segment is defined by
/// the offset into the original array, and an length.
/// </summary>
/// <typeparam name="T">The type of object stored in the underlying <see cref="Array"/></typeparam>
internal readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
where T : struct
{
/// <summary>
/// The zero-based offset into the original array at which this segment starts.
/// </summary>
public readonly int Offset;
/// <summary>
/// The number of items this segment can access in the underlying array.
/// </summary>
public readonly int Length;
/// <summary>
/// An Empty segment which has an offset of 0, a Length of 0, and it's underlying array
/// is also empty.
/// </summary>
public static ActionSegment<T> Empty = new ActionSegment<T>(System.Array.Empty<T>(), 0, 0);
static void CheckParameters(T[] actionArray, int offset, int length)
{
#if DEBUG
if (offset + length > actionArray.Length)
{
throw new ArgumentOutOfRangeException(nameof(offset),
$"Arguments offset: {offset} and length: {length} " +
$"are out of bounds of actionArray: {actionArray.Length}.");
}
#endif
}
/// <summary>
/// Construct an <see cref="ActionSegment{T}"/> with an underlying array
/// and offset, and a length.
/// </summary>
/// <param name="actionArray">The underlying array which this segment has a view into</param>
/// <param name="offset">The zero-based offset into the underlying array.</param>
/// <param name="length">The length of the segment.</param>
public ActionSegment(T[] actionArray, int offset, int length)
{
CheckParameters(actionArray, offset, length);
Array = actionArray;
Offset = offset;
Length = length;
}
/// <summary>
/// Get the underlying <see cref="Array"/> of this segment.
/// </summary>
public T[] Array { get; }
/// <summary>
/// Allows access to the underlying array using array syntax.
/// </summary>
/// <param name="index">The zero-based index of the segment.</param>
/// <exception cref="IndexOutOfRangeException">Thrown when the index is less than 0 or
/// greater than or equal to <see cref="Length"/></exception>
public T this[int index]
{
get
{
if (index < 0 || index > Length)
{
throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length}");
}
return Array[Offset + index];
}
}
/// <inheritdoc cref="IEnumerable{T}.GetEnumerator"/>
IEnumerator<T> IEnumerable<T>.GetEnumerator()
{
return new Enumerator(this);
}
/// <inheritdoc cref="IEnumerable{T}"/>
public IEnumerator GetEnumerator()
{
return new Enumerator(this);
}
/// <inheritdoc cref="ValueType.Equals(object)"/>
public override bool Equals(object obj)
{
if (!(obj is ActionSegment<T>))
{
return false;
}
return Equals((ActionSegment<T>)obj);
}
/// <inheritdoc cref="IEquatable{T}.Equals(T)"/>
public bool Equals(ActionSegment<T> other)
{
return Offset == other.Offset && Length == other.Length && Equals(Array, other.Array);
}
/// <inheritdoc cref="ValueType.GetHashCode"/>
public override int GetHashCode()
{
unchecked
{
var hashCode = Offset;
hashCode = (hashCode * 397) ^ Length;
hashCode = (hashCode * 397) ^ (Array != null ? Array.GetHashCode() : 0);
return hashCode;
}
}
/// <summary>
/// A private <see cref="IEnumerator{T}"/> for the <see cref="ActionSegment{T}"/> value type which follows its
/// rules of being a view into an underlying <see cref="Array"/>.
/// </summary>
struct Enumerator : IEnumerator<T>
{
readonly T[] m_Array;
readonly int m_Start;
readonly int m_End; // cache Offset + Count, since it's a little slow
int m_Current;
internal Enumerator(ActionSegment<T> arraySegment)
{
Debug.Assert(arraySegment.Array != null);
Debug.Assert(arraySegment.Offset >= 0);
Debug.Assert(arraySegment.Length >= 0);
Debug.Assert(arraySegment.Offset + arraySegment.Length <= arraySegment.Array.Length);
m_Array = arraySegment.Array;
m_Start = arraySegment.Offset;
m_End = arraySegment.Offset + arraySegment.Length;
m_Current = arraySegment.Offset - 1;
}
public bool MoveNext()
{
if (m_Current < m_End)
{
m_Current++;
return m_Current < m_End;
}
return false;
}
public T Current
{
get
{
if (m_Current < m_Start)
throw new InvalidOperationException("Enumerator not started.");
if (m_Current >= m_End)
throw new InvalidOperationException("Enumerator has reached the end already.");
return m_Array[m_Current];
}
}
object IEnumerator.Current => Current;
void IEnumerator.Reset()
{
m_Current = m_Start - 1;
}
public void Dispose()
{
}
}
}
}

3
com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta


fileFormatVersion: 2
guid: 4fa1432c1ba3460caaa84303a9011ef2
timeCreated: 1595869823

75
com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs


using System;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents.Policies;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Defines the structure of an Action Space to be used by the Actuator system.
/// </summary>
internal readonly struct ActionSpec
{
/// <summary>
/// An array of branch sizes for our action space.
///
/// For an IActuator that uses a Discrete <see cref="SpaceType"/>, the number of
/// branches is the Length of the Array and each index contains the branch size.
/// The cumulative sum of the total number of discrete actions can be retrieved
/// by the <see cref="SumOfDiscreteBranchSizes"/> property.
///
/// For an IActuator with a Continuous it will be null.
/// </summary>
public readonly int[] BranchSizes;
/// <summary>
/// The number of actions for a Continuous <see cref="SpaceType"/>.
/// </summary>
public int NumContinuousActions { get; }
/// <summary>
/// The number of branches for a Discrete <see cref="SpaceType"/>.
/// </summary>
public int NumDiscreteActions { get; }
/// <summary>
/// Get the total number of Discrete Actions that can be taken by calculating the Sum
/// of all of the Discrete Action branch sizes.
/// </summary>
public int SumOfDiscreteBranchSizes { get; }
/// <summary>
/// Creates a Continuous <see cref="ActionSpec"/> with the number of actions available.
/// </summary>
/// <param name="numActions">The number of actions available.</param>
/// <returns>An Continuous ActionSpec initialized with the number of actions available.</returns>
public static ActionSpec MakeContinuous(int numActions)
{
var actuatorSpace = new ActionSpec(numActions, 0);
return actuatorSpace;
}
/// <summary>
/// Creates a Discrete <see cref="ActionSpec"/> with the array of branch sizes that
/// represents the action space.
/// </summary>
/// <param name="branchSizes">The array of branch sizes for the discrete action space. Each index
/// contains the number of actions available for that branch.</param>
/// <returns>An Discrete ActionSpec initialized with the array of branch sizes.</returns>
public static ActionSpec MakeDiscrete(int[] branchSizes)
{
var numActions = branchSizes.Length;
var actuatorSpace = new ActionSpec(0, numActions, branchSizes);
return actuatorSpace;
}
ActionSpec(int numContinuousActions, int numDiscreteActions, int[] branchSizes = null)
{
NumContinuousActions = numContinuousActions;
NumDiscreteActions = numDiscreteActions;
BranchSizes = branchSizes;
SumOfDiscreteBranchSizes = branchSizes?.Sum() ?? 0;
}
}
}

3
com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta


fileFormatVersion: 2
guid: ecdd6deefba1416ca149fe09d2a5afd8
timeCreated: 1595892361

17
com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs


using UnityEngine;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Editor components for creating Actuators. Generally an IActuator component should
/// have a corresponding ActuatorComponent.
/// </summary>
internal abstract class ActuatorComponent : MonoBehaviour
{
/// <summary>
/// Create the IActuator. This is called by the Agent when it is initialized.
/// </summary>
/// <returns>Created IActuator object.</returns>
public abstract IActuator CreateActuator();
}
}

3
com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta


fileFormatVersion: 2
guid: 77cefae5f6d841be9ff80b41293d271b
timeCreated: 1593017318

150
com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs


using System;
using System.Collections.Generic;
using System.Linq;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Implementation of IDiscreteActionMask that allows writing to the action mask from an <see cref="IActuator"/>.
/// </summary>
internal class ActuatorDiscreteActionMask : IDiscreteActionMask
{
/// When using discrete control, is the starting indices of the actions
/// when all the branches are concatenated with each other.
int[] m_StartingActionIndices;
int[] m_BranchSizes;
bool[] m_CurrentMask;
IList<IActuator> m_Actuators;
readonly int m_SumOfDiscreteBranchSizes;
readonly int m_NumBranches;
/// <summary>
/// The offset into the branches array that is used when actuators are writing to the action mask.
/// </summary>
public int CurrentBranchOffset { get; set; }
internal ActuatorDiscreteActionMask(IList<IActuator> actuators, int sumOfDiscreteBranchSizes, int numBranches)
{
m_Actuators = actuators;
m_SumOfDiscreteBranchSizes = sumOfDiscreteBranchSizes;
m_NumBranches = numBranches;
}
/// <inheritdoc cref="IDiscreteActionMask.WriteMask"/>
public void WriteMask(int branch, IEnumerable<int> actionIndices)
{
LazyInitialize();
// Perform the masking
foreach (var actionIndex in actionIndices)
{
#if DEBUG
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
{
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
#endif
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = true;
}
}
void LazyInitialize()
{
if (m_BranchSizes == null)
{
m_BranchSizes = new int[m_NumBranches];
var start = 0;
for (var i = 0; i < m_Actuators.Count; i++)
{
var actuator = m_Actuators[i];
var branchSizes = actuator.ActionSpec.BranchSizes;
Array.Copy(branchSizes, 0, m_BranchSizes, start, branchSizes.Length);
start += branchSizes.Length;
}
}
// By default, the masks are null. If we want to specify a new mask, we initialize
// the actionMasks with trues.
if (m_CurrentMask == null)
{
m_CurrentMask = new bool[m_SumOfDiscreteBranchSizes];
}
// If this is the first time the masked actions are used, we generate the starting
// indices for each branch.
if (m_StartingActionIndices == null)
{
m_StartingActionIndices = Utilities.CumSum(m_BranchSizes);
}
}
/// <inheritdoc cref="IDiscreteActionMask.GetMask"/>
public bool[] GetMask()
{
#if DEBUG
if (m_CurrentMask != null)
{
AssertMask();
}
#endif
return m_CurrentMask;
}
/// <summary>
/// Makes sure that the current mask is usable.
/// </summary>
void AssertMask()
{
#if DEBUG
for (var branchIndex = 0; branchIndex < m_NumBranches; branchIndex++)
{
if (AreAllActionsMasked(branchIndex))
{
throw new UnityAgentsException(
"Invalid Action Masking : All the actions of branch " + branchIndex +
" are masked.");
}
}
#endif
}
/// <summary>
/// Resets the current mask for an agent.
/// </summary>
public void ResetMask()
{
if (m_CurrentMask != null)
{
Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
}
}
/// <summary>
/// Checks if all the actions in the input branch are masked.
/// </summary>
/// <param name="branch"> The index of the branch to check.</param>
/// <returns> True if all the actions of the branch are masked.</returns>
bool AreAllActionsMasked(int branch)
{
if (m_CurrentMask == null)
{
return false;
}
var start = m_StartingActionIndices[branch];
var end = m_StartingActionIndices[branch + 1];
for (var i = start; i < end; i++)
{
if (!m_CurrentMask[i])
{
return false;
}
}
return true;
}
}
}

3
com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta


fileFormatVersion: 2
guid: d2a19e2f43fd4637a38d42b2a5f989f3
timeCreated: 1595459316

415
com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs


using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// A class that manages the delegation of events, action buffers, and action mask for a list of IActuators.
/// </summary>
internal class ActuatorManager : IList<IActuator>
{
// IActuators managed by this object.
IList<IActuator> m_Actuators;
// An implementation of IDiscreteActionMask that allows for writing to it based on an offset.
ActuatorDiscreteActionMask m_DiscreteActionMask;
/// <summary>
/// Flag used to check if our IActuators are ready for execution.
/// </summary>
/// <seealso cref="ReadyActuatorsForExecution(IList{IActuator}, int, int, int)"/>
bool m_ReadyForExecution;
/// <summary>
/// The sum of all of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
/// </summary>
internal int SumOfDiscreteBranchSizes { get; private set; }
/// <summary>
/// The number of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
/// </summary>
internal int NumDiscreteActions { get; private set; }
/// <summary>
/// The number of continuous actions for all of the <see cref="IActuator"/>s in this manager.
/// </summary>
internal int NumContinuousActions { get; private set; }
/// <summary>
/// Returns the total actions which is calculated by <see cref="NumContinuousActions"/> + <see cref="NumDiscreteActions"/>.
/// </summary>
public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions;
/// <summary>
/// Gets the <see cref="IDiscreteActionMask"/> managed by this object.
/// </summary>
public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask;
/// <summary>
/// Returns the previously stored actions for the actuators in this list.
/// </summary>
public float[] StoredContinuousActions { get; private set; }
/// <summary>
/// Returns the previously stored actions for the actuators in this list.
/// </summary>
public int[] StoredDiscreteActions { get; private set; }
/// <summary>
/// Create an ActuatorList with a preset capacity.
/// </summary>
/// <param name="capacity">The capacity of the list to create.</param>
public ActuatorManager(int capacity = 0)
{
m_Actuators = new List<IActuator>(capacity);
}
/// <summary>
/// <see cref="ReadyActuatorsForExecution(IList{IActuator}, int, int, int)"/>
/// </summary>
void ReadyActuatorsForExecution()
{
ReadyActuatorsForExecution(m_Actuators, NumContinuousActions, SumOfDiscreteBranchSizes,
NumDiscreteActions);
}
/// <summary>
/// This method validates that all <see cref="IActuator"/>s have unique names and equivalent action space types
/// if the `DEBUG` preprocessor macro is defined, and allocates the appropriate buffers to manage the actions for
/// all of the <see cref="IActuator"/>s that may live on a particular object.
/// </summary>
/// <param name="actuators">The list of actuators to validate and allocate buffers for.</param>
/// <param name="numContinuousActions">The total number of continuous actions for all of the actuators.</param>
/// <param name="sumOfDiscreteBranches">The total sum of the discrete branches for all of the actuators in order
/// to be able to allocate an <see cref="IDiscreteActionMask"/>.</param>
/// <param name="numDiscreteBranches">The number of discrete branches for all of the actuators.</param>
internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numContinuousActions, int sumOfDiscreteBranches, int numDiscreteBranches)
{
if (m_ReadyForExecution)
{
return;
}
#if DEBUG
// Make sure the names are actually unique
// Make sure all Actuators have the same SpaceType
ValidateActuators();
#endif
// Sort the Actuators by name to ensure determinism
SortActuators();
StoredContinuousActions = numContinuousActions == 0 ? Array.Empty<float>() : new float[numContinuousActions];
StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty<int>() : new int[numDiscreteBranches];
m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches);
m_ReadyForExecution = true;
}
/// <summary>
/// Updates the local action buffer with the action buffer passed in. If the buffer
/// passed in is null, the local action buffer will be cleared.
/// </summary>
/// <param name="continuousActionBuffer">The action buffer which contains all of the
/// continuous actions for the IActuators in this list.</param>
/// <param name="discreteActionBuffer">The action buffer which contains all of the
/// discrete actions for the IActuators in this list.</param>
public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer)
{
ReadyActuatorsForExecution();
UpdateActionArray(continuousActionBuffer, StoredContinuousActions);
UpdateActionArray(discreteActionBuffer, StoredDiscreteActions);
}
static void UpdateActionArray<T>(T[] sourceActionBuffer, T[] destination)
{
if (sourceActionBuffer == null || sourceActionBuffer.Length == 0)
{
Array.Clear(destination, 0, destination.Length);
}
else
{
Debug.Assert(sourceActionBuffer.Length == destination.Length,
$"sourceActionBuffer:{sourceActionBuffer.Length} is a different" +
$" size than destination: {destination.Length}.");
Array.Copy(sourceActionBuffer, destination, destination.Length);
}
}
/// <summary>
/// This method will trigger the writing to the <see cref="IDiscreteActionMask"/> by all of the actuators
/// managed by this object.
/// </summary>
public void WriteActionMask()
{
ReadyActuatorsForExecution();
m_DiscreteActionMask.ResetMask();
var offset = 0;
for (var i = 0; i < m_Actuators.Count; i++)
{
var actuator = m_Actuators[i];
m_DiscreteActionMask.CurrentBranchOffset = offset;
actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
offset += actuator.ActionSpec.NumDiscreteActions;
}
}
/// <summary>
/// Iterates through all of the IActuators in this list and calls their
/// <see cref="IActionReceiver.OnActionReceived"/> method on them with the appropriate
/// <see cref="ActionSegment{T}"/>s depending on their <see cref="IActionReceiver.ActionSpec"/>.
/// </summary>
public void ExecuteActions()
{
ReadyActuatorsForExecution();
var continuousStart = 0;
var discreteStart = 0;
for (var i = 0; i < m_Actuators.Count; i++)
{
var actuator = m_Actuators[i];
var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;
var continuousActions = ActionSegment<float>.Empty;
if (numContinuousActions > 0)
{
continuousActions = new ActionSegment<float>(StoredContinuousActions,
continuousStart,
numContinuousActions);
}
var discreteActions = ActionSegment<int>.Empty;
if (numDiscreteActions > 0)
{
discreteActions = new ActionSegment<int>(StoredDiscreteActions,
discreteStart,
numDiscreteActions);
}
actuator.OnActionReceived(new ActionBuffers(continuousActions, discreteActions));
continuousStart += numContinuousActions;
discreteStart += numDiscreteActions;
}
}
/// <summary>
/// Resets the <see cref="StoredContinuousActions"/> and <see cref="StoredDiscreteActions"/> buffers to be all
/// zeros and calls <see cref="IActuator.ResetData"/> on each <see cref="IActuator"/> managed by this object.
/// </summary>
public void ResetData()
{
if (!m_ReadyForExecution)
{
return;
}
Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length);
Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length);
for (var i = 0; i < m_Actuators.Count; i++)
{
m_Actuators[i].ResetData();
}
}
/// <summary>
/// Sorts the <see cref="IActuator"/>s according to their <see cref="IActuator.GetName"/> value.
/// </summary>
void SortActuators()
{
((List<IActuator>)m_Actuators).Sort((x,
y) => x.Name
.CompareTo(y.Name));
}
/// <summary>
/// Validates that the IActuators managed by this object have unique names and equivalent action space types.
/// Each Actuator needs to have a unique name in order for this object to ensure that the storage of action
/// buffers, and execution of Actuators remains deterministic across different sessions of running.
/// </summary>
void ValidateActuators()
{
for (var i = 0; i < m_Actuators.Count - 1; i++)
{
Debug.Assert(
!m_Actuators[i].Name.Equals(m_Actuators[i + 1].Name),
"Actuator names must be unique.");
var first = m_Actuators[i].ActionSpec;
var second = m_Actuators[i + 1].ActionSpec;
Debug.Assert(first.NumContinuousActions > 0 == second.NumContinuousActions > 0,
"Actuators on the same Agent must have the same action SpaceType.");
}
}
/// <summary>
/// Helper method to update bookkeeping items around buffer management for actuators added to this object.
/// </summary>
/// <param name="actuatorItem">The IActuator to keep bookkeeping for.</param>
void AddToBufferSizes(IActuator actuatorItem)
{
if (actuatorItem == null)
{
return;
}
NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions;
NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions;
SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
}
/// <summary>
/// Helper method to update bookkeeping items around buffer management for actuators removed from this object.
/// </summary>
/// <param name="actuatorItem">The IActuator to keep bookkeeping for.</param>
void SubtractFromBufferSize(IActuator actuatorItem)
{
if (actuatorItem == null)
{
return;
}
NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions;
NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions;
SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
}
/// <summary>
/// Sets all of the bookkeeping items back to 0.
/// </summary>
void ClearBufferSizes()
{
NumContinuousActions = NumDiscreteActions = SumOfDiscreteBranchSizes = 0;
}
/*********************************************************************************
* IList implementation that delegates to m_Actuators List. *
*********************************************************************************/
/// <summary>
/// <inheritdoc cref="IEnumerable{T}.GetEnumerator"/>
/// </summary>
public IEnumerator<IActuator> GetEnumerator()
{
return m_Actuators.GetEnumerator();
}
/// <summary>
/// <inheritdoc cref="IList{T}.GetEnumerator"/>
/// </summary>
IEnumerator IEnumerable.GetEnumerator()
{
return ((IEnumerable)m_Actuators).GetEnumerator();
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.Add"/>
/// </summary>
/// <param name="item"></param>
public void Add(IActuator item)
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot add to the ActuatorManager after its buffers have been initialized");
m_Actuators.Add(item);
AddToBufferSizes(item);
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.Clear"/>
/// </summary>
public void Clear()
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot clear the ActuatorManager after its buffers have been initialized");
m_Actuators.Clear();
ClearBufferSizes();
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.Contains"/>
/// </summary>
public bool Contains(IActuator item)
{
return m_Actuators.Contains(item);
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.CopyTo"/>
/// </summary>
public void CopyTo(IActuator[] array, int arrayIndex)
{
m_Actuators.CopyTo(array, arrayIndex);
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.Remove"/>
/// </summary>
public bool Remove(IActuator item)
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot remove from the ActuatorManager after its buffers have been initialized");
if (m_Actuators.Remove(item))
{
SubtractFromBufferSize(item);
return true;
}
return false;
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.Count"/>
/// </summary>
public int Count => m_Actuators.Count;
/// <summary>
/// <inheritdoc cref="ICollection{T}.IsReadOnly"/>
/// </summary>
public bool IsReadOnly => m_Actuators.IsReadOnly;
/// <summary>
/// <inheritdoc cref="IList{T}.IndexOf"/>
/// </summary>
public int IndexOf(IActuator item)
{
return m_Actuators.IndexOf(item);
}
/// <summary>
/// <inheritdoc cref="IList{T}.Insert"/>
/// </summary>
public void Insert(int index, IActuator item)
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot insert into the ActuatorManager after its buffers have been initialized");
m_Actuators.Insert(index, item);
AddToBufferSizes(item);
}
/// <summary>
/// <inheritdoc cref="IList{T}.RemoveAt"/>
/// </summary>
public void RemoveAt(int index)
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot remove from the ActuatorManager after its buffers have been initialized");
var actuator = m_Actuators[index];
SubtractFromBufferSize(actuator);
m_Actuators.RemoveAt(index);
}
/// <summary>
/// <inheritdoc cref="IList{T}.this"/>
/// </summary>
public IActuator this[int index]
{
get => m_Actuators[index];
set
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot modify the ActuatorManager after its buffers have been initialized");
var old = m_Actuators[index];
SubtractFromBufferSize(old);
m_Actuators[index] = value;
AddToBufferSizes(value);
}
}
}
}

3
com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta


fileFormatVersion: 2
guid: 7bb5b1e3779d4342a8e70f6e3c1d67cc
timeCreated: 1593031463

101
com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs


using System;
using System.Linq;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// A structure that wraps the <see cref="ActionSegment{T}"/>s for a particular <see cref="IActionReceiver"/> and is
/// used when <see cref="IActionReceiver.OnActionReceived"/> is called.
/// </summary>
internal readonly struct ActionBuffers
{
/// <summary>
/// An empty action buffer.
/// </summary>
public static ActionBuffers Empty = new ActionBuffers(ActionSegment<float>.Empty, ActionSegment<int>.Empty);
/// <summary>
/// Holds the Continuous <see cref="ActionSegment{T}"/> to be used by an <see cref="IActionReceiver"/>.
/// </summary>
public ActionSegment<float> ContinuousActions { get; }
/// <summary>
/// Holds the Discrete <see cref="ActionSegment{T}"/> to be used by an <see cref="IActionReceiver"/>.
/// </summary>
public ActionSegment<int> DiscreteActions { get; }
/// <summary>
/// Construct an <see cref="ActionBuffers"/> instance with the continuous and discrete actions that will
/// be used.
/// </summary>
/// <param name="continuousActions">The continuous actions to send to an <see cref="IActionReceiver"/>.</param>
/// <param name="discreteActions">The discrete actions to send to an <see cref="IActionReceiver"/>.</param>
public ActionBuffers(ActionSegment<float> continuousActions, ActionSegment<int> discreteActions)
{
ContinuousActions = continuousActions;
DiscreteActions = discreteActions;
}
/// <inheritdoc cref="ValueType.Equals(object)"/>
public override bool Equals(object obj)
{
if (!(obj is ActionBuffers))
{
return false;
}
var ab = (ActionBuffers)obj;
return ab.ContinuousActions.SequenceEqual(ContinuousActions) &&
ab.DiscreteActions.SequenceEqual(DiscreteActions);
}
/// <inheritdoc cref="ValueType.GetHashCode"/>
public override int GetHashCode()
{
unchecked
{
return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode();
}
}
}
/// <summary>
/// An interface that describes an object that can receive actions from a Reinforcement Learning network.
/// </summary>
internal interface IActionReceiver
{
/// <summary>
/// The specification of the Action space for this IActionReceiver.
/// </summary>
/// <seealso cref="ActionSpec"/>
ActionSpec ActionSpec { get; }
/// <summary>
/// Method called in order too allow object to execute actions based on the
/// <see cref="ActionBuffers"/> contents. The structure of the contents in the <see cref="ActionBuffers"/>
/// are defined by the <see cref="ActionSpec"/>.
/// </summary>
/// <param name="actionBuffers">The data structure containing the action buffers for this object.</param>
void OnActionReceived(ActionBuffers actionBuffers);
/// <summary>
/// Implement `WriteDiscreteActionMask()` to modify the masks for discrete
/// actions. When using discrete actions, the agent will not perform the masked
/// action.
/// </summary>
/// <param name="actionMask">
/// The action mask for the agent.
/// </param>
/// <remarks>
/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it with <see cref="IDiscreteActionMask.WriteMask"/>.
///
/// See [Agents - Actions] for more information on masking actions.
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <seealso cref="IActionReceiver.OnActionReceived"/>
void WriteDiscreteActionMask(IDiscreteActionMask actionMask);
}
}

3
com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta


fileFormatVersion: 2
guid: b25a5b3027c9476ea1a310241be0f10f
timeCreated: 1594756775

21
com.unity.ml-agents/Runtime/Actuators/IActuator.cs


using System;
using UnityEngine;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Abstraction that facilitates the execution of actions.
/// </summary>
internal interface IActuator : IActionReceiver
{
int TotalNumberOfActions { get; }
/// <summary>
/// Gets the name of this IActuator which will be used to sort it.
/// </summary>
/// <returns></returns>
string Name { get; }
void ResetData();
}
}

3
com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta


fileFormatVersion: 2
guid: 780d7f0a675f44bfa784b370025b51c3
timeCreated: 1592848317

38
com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs


using System.Collections.Generic;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Interface for writing a mask to disable discrete actions for agents for the next decision.
/// </summary>
internal interface IDiscreteActionMask
{
/// <summary>
/// Modifies an action mask for discrete control agents.
/// </summary>
/// <remarks>
/// When used, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndices correspond
/// to the action options the agent will be unable to perform.
///
/// See [Agents - Actions] for more information on masking actions.
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_2_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <param name="branch">The branch for which the actions will be masked.</param>
/// <param name="actionIndices">The indices of the masked actions.</param>
void WriteMask(int branch, IEnumerable<int> actionIndices);
/// <summary>
/// Get the current mask for an agent.
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
bool[] GetMask();
/// <summary>
/// Resets the current mask for an agent.
/// </summary>
void ResetMask();
}
}

3
com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta


fileFormatVersion: 2
guid: 1bc4e4b71bf4470789488fab2ee65388
timeCreated: 1595369065

72
com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs


using System;
using Unity.MLAgents.Policies;
namespace Unity.MLAgents.Actuators
{
internal class VectorActuator : IActuator
{
IActionReceiver m_ActionReceiver;
ActionBuffers m_ActionBuffers;
internal ActionBuffers ActionBuffers
{
get => m_ActionBuffers;
private set => m_ActionBuffers = value;
}
public VectorActuator(IActionReceiver actionReceiver,
int[] vectorActionSize,
SpaceType spaceType,
string name = "VectorActuator")
{
m_ActionReceiver = actionReceiver;
string suffix;
switch (spaceType)
{
case SpaceType.Continuous:
ActionSpec = ActionSpec.MakeContinuous(vectorActionSize[0]);
suffix = "-Continuous";
break;
case SpaceType.Discrete:
ActionSpec = ActionSpec.MakeDiscrete(vectorActionSize);
suffix = "-Discrete";
break;
default:
throw new ArgumentOutOfRangeException(nameof(spaceType),
spaceType,
"Unknown enum value.");
}
Name = name + suffix;
}
public void ResetData()
{
m_ActionBuffers = ActionBuffers.Empty;
}
public void OnActionReceived(ActionBuffers actionBuffers)
{
ActionBuffers = actionBuffers;
m_ActionReceiver.OnActionReceived(ActionBuffers);
}
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
m_ActionReceiver.WriteDiscreteActionMask(actionMask);
}
/// <summary>
/// Returns the number of discrete branches + the number of continuous actions.
/// </summary>
public int TotalNumberOfActions => ActionSpec.NumContinuousActions +
ActionSpec.NumDiscreteActions;
/// <summary>
/// <inheritdoc cref="IActionReceiver.ActionSpec"/>
/// </summary>
public ActionSpec ActionSpec { get; }
public string Name { get; }
}
}

3
com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta


fileFormatVersion: 2
guid: ff7a3292c0b24b23b3f1c0eeb690ec4c
timeCreated: 1593023833

55
com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs


using System;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
namespace Unity.MLAgents.Tests.Actuators
{
[TestFixture]
public class ActionSegmentTests
{
[Test]
public void TestConstruction()
{
var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
Assert.Throws<ArgumentOutOfRangeException>(
() => new ActionSegment<float>(floatArray, 100, 1));
var segment = new ActionSegment<float>(Array.Empty<float>(), 0, 0);
Assert.AreEqual(segment, ActionSegment<float>.Empty);
}
[Test]
public void TestIndexing()
{
var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
for (var i = 0; i < floatArray.Length; i++)
{
var start = 0 + i;
var length = floatArray.Length - i;
var actionSegment = new ActionSegment<float>(floatArray, start, length);
for (var j = 0; j < actionSegment.Length; j++)
{
Assert.AreEqual(actionSegment[j], floatArray[start + j]);
}
}
}
[Test]
public void TestEnumerator()
{
var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
for (var i = 0; i < floatArray.Length; i++)
{
var start = 0 + i;
var length = floatArray.Length - i;
var actionSegment = new ActionSegment<float>(floatArray, start, length);
var j = 0;
foreach (var item in actionSegment)
{
Assert.AreEqual(item, floatArray[start + j++]);
}
}
}
}
}

3
com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs.meta


fileFormatVersion: 2
guid: 18cb6d052fba43a2b7437d87c0d9abad
timeCreated: 1596486604

114
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs


using System;
using System.Collections.Generic;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
namespace Unity.MLAgents.Tests.Actuators
{
[TestFixture]
public class ActuatorDiscreteActionMaskTests
{
[Test]
public void Construction()
{
var masker = new ActuatorDiscreteActionMask(new List<IActuator>(), 0, 0);
Assert.IsNotNull(masker);
}
[Test]
public void NullMask()
{
var masker = new ActuatorDiscreteActionMask(new List<IActuator>(), 0, 0);
var mask = masker.GetMask();
Assert.IsNull(mask);
}
[Test]
public void FirstBranchMask()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
var mask = masker.GetMask();
Assert.IsNull(mask);
masker.WriteMask(0, new[] {1, 2, 3});
mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsTrue(mask[1]);
Assert.IsTrue(mask[2]);
Assert.IsTrue(mask[3]);
Assert.IsFalse(mask[4]);
Assert.AreEqual(mask.Length, 15);
}
[Test]
public void SecondBranchMask()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
var masker = new ActuatorDiscreteActionMask(new[] {actuator1}, 15, 3);
masker.WriteMask(1, new[] {1, 2, 3});
var mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsFalse(mask[4]);
Assert.IsTrue(mask[5]);
Assert.IsTrue(mask[6]);
Assert.IsTrue(mask[7]);
Assert.IsFalse(mask[8]);
Assert.IsFalse(mask[9]);
}
[Test]
public void MaskReset()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
masker.WriteMask(1, new[] {1, 2, 3});
masker.ResetMask();
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)
{
Assert.IsFalse(mask[i]);
}
}
[Test]
public void ThrowsError()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
Assert.Catch<UnityAgentsException>(
() => masker.WriteMask(0, new[] {5}));
Assert.Catch<UnityAgentsException>(
() => masker.WriteMask(1, new[] {5}));
masker.WriteMask(2, new[] {5});
Assert.Catch<UnityAgentsException>(
() => masker.WriteMask(3, new[] {1}));
masker.GetMask();
masker.ResetMask();
masker.WriteMask(0, new[] {0, 1, 2, 3});
Assert.Catch<UnityAgentsException>(
() => masker.GetMask());
}
[Test]
public void MultipleMaskEdit()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
masker.WriteMask(0, new[] {0, 1});
masker.WriteMask(0, new[] {3});
masker.WriteMask(2, new[] {1});
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)
{
if ((i == 0) || (i == 1) || (i == 3) || (i == 10))
{
Assert.IsTrue(mask[i]);
}
else
{
Assert.IsFalse(mask[i]);
}
}
}
}
}

3
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs.meta


fileFormatVersion: 2
guid: b9f5f87049d04d8bba39d193a3ab2f5a
timeCreated: 1596491682

310
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using UnityEngine;
using UnityEngine.TestTools;
using Assert = UnityEngine.Assertions.Assert;
namespace Unity.MLAgents.Tests.Actuators
{
[TestFixture]
public class ActuatorManagerTests
{
[Test]
public void TestEnsureBufferSizeContinuous()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(10), "actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(2), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var actuator1ActionSpaceDef = actuator1.ActionSpec;
var actuator2ActionSpaceDef = actuator2.ActionSpec;
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 },
actuator1ActionSpaceDef.NumContinuousActions + actuator2ActionSpaceDef.NumContinuousActions,
actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes,
actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions);
manager.UpdateActions(new[]
{ 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f }, Array.Empty<int>());
Assert.IsTrue(12 == manager.NumContinuousActions);
Assert.IsTrue(0 == manager.NumDiscreteActions);
Assert.IsTrue(0 == manager.SumOfDiscreteBranchSizes);
Assert.IsTrue(12 == manager.StoredContinuousActions.Length);
Assert.IsTrue(0 == manager.StoredDiscreteActions.Length);
}
[Test]
public void TestEnsureBufferDiscrete()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 1, 1}), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var actuator1ActionSpaceDef = actuator1.ActionSpec;
var actuator2ActionSpaceDef = actuator2.ActionSpec;
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 },
actuator1ActionSpaceDef.NumContinuousActions + actuator2ActionSpaceDef.NumContinuousActions,
actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes,
actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions);
manager.UpdateActions(Array.Empty<float>(),
new[] { 0, 1, 2, 3, 4, 5, 6});
Assert.IsTrue(0 == manager.NumContinuousActions);
Assert.IsTrue(7 == manager.NumDiscreteActions);
Assert.IsTrue(13 == manager.SumOfDiscreteBranchSizes);
Assert.IsTrue(0 == manager.StoredContinuousActions.Length);
Assert.IsTrue(7 == manager.StoredDiscreteActions.Length);
}
[Test]
public void TestFailOnMixedActionSpace()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, 3, 10, 4);
LogAssert.Expect(LogType.Assert, "Actuators on the same Agent must have the same action SpaceType.");
}
[Test]
public void TestFailOnSameActuatorName()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator1");
manager.Add(actuator1);
manager.Add(actuator2);
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, 3, 10, 4);
LogAssert.Expect(LogType.Assert, "Actuator names must be unique.");
}
[Test]
public void TestExecuteActionsDiscrete()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 1, 1}), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5, 6};
manager.UpdateActions(Array.Empty<float>(),
discreteActionBuffer);
manager.ExecuteActions();
var actuator1Actions = actuator1.LastActionBuffer.DiscreteActions;
var actuator2Actions = actuator2.LastActionBuffer.DiscreteActions;
TestSegmentEquality(actuator1Actions, discreteActionBuffer); TestSegmentEquality(actuator2Actions, discreteActionBuffer);
}
[Test]
public void TestExecuteActionsContinuous()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
manager.UpdateActions(continuousActionBuffer,
Array.Empty<int>());
manager.ExecuteActions();
var actuator1Actions = actuator1.LastActionBuffer.ContinuousActions;
var actuator2Actions = actuator2.LastActionBuffer.ContinuousActions;
TestSegmentEquality(actuator1Actions, continuousActionBuffer);
TestSegmentEquality(actuator2Actions, continuousActionBuffer);
}
static void TestSegmentEquality<T>(ActionSegment<T> actionSegment, T[] actionBuffer)
where T : struct
{
Assert.IsFalse(actionSegment.Length == 0);
for (var i = 0; i < actionSegment.Length; i++)
{
var action = actionSegment[i];
Assert.AreEqual(action, actionBuffer[actionSegment.Offset + i]);
}
}
[Test]
public void TestUpdateActionsContinuous()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
manager.UpdateActions(continuousActionBuffer,
Array.Empty<int>());
Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer));
}
[Test]
public void TestUpdateActionsDiscrete()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5};
manager.UpdateActions(Array.Empty<float>(),
discreteActionBuffer);
Debug.Log(manager.StoredDiscreteActions);
Debug.Log(discreteActionBuffer);
Assert.IsTrue(manager.StoredDiscreteActions.SequenceEqual(discreteActionBuffer));
}
[Test]
public void TestRemove()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
Assert.IsTrue(manager.NumDiscreteActions == 6);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 12);
manager.Remove(actuator2);
Assert.IsTrue(manager.NumDiscreteActions == 3);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6);
manager.Remove(null);
Assert.IsTrue(manager.NumDiscreteActions == 3);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6);
manager.RemoveAt(0);
Assert.IsTrue(manager.NumDiscreteActions == 0);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 0);
}
[Test]
public void TestClear()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
Assert.IsTrue(manager.NumDiscreteActions == 6);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 12);
manager.Clear();
Assert.IsTrue(manager.NumDiscreteActions == 0);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 0);
}
[Test]
public void TestIndexSet()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4}),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
manager.Add(actuator1);
Assert.IsTrue(manager.NumDiscreteActions == 4);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 10);
manager[0] = actuator2;
Assert.IsTrue(manager.NumDiscreteActions == 3);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6);
}
[Test]
public void TestInsert()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4}),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
manager.Add(actuator1);
Assert.IsTrue(manager.NumDiscreteActions == 4);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 10);
manager.Insert(0, actuator2);
Assert.IsTrue(manager.NumDiscreteActions == 7);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 16);
Assert.IsTrue(manager.IndexOf(actuator2) == 0);
}
[Test]
public void TestResetData()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
manager.UpdateActions(continuousActionBuffer,
Array.Empty<int>());
Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer));
Assert.IsTrue(manager.NumContinuousActions == 6);
manager.ResetData();
Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(new[] { 0f, 0f, 0f, 0f, 0f, 0f}));
}
[Test]
public void TestWriteDiscreteActionMask()
{
var manager = new ActuatorManager(2);
var va1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "name");
var va2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {3, 2, 1}), "name1");
manager.Add(va1);
manager.Add(va2);
var groundTruthMask = new[]
{
false,
true, false,
false, true, true,
true, false, true,
false, true,
false
};
va1.Masks = new[]
{
Array.Empty<int>(),
new[] { 0 },
new[] { 1, 2 }
};
va2.Masks = new[]
{
new[] {0, 2},
new[] {1},
Array.Empty<int>()
};
manager.WriteActionMask();
Assert.IsTrue(groundTruthMask.SequenceEqual(manager.DiscreteActionMask.GetMask()));
}
}
}

3
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs.meta


fileFormatVersion: 2
guid: d48ba72f0ac64d7db0af22c9d82b11d8
timeCreated: 1596494279

38
com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs


using Unity.MLAgents.Actuators;
namespace Unity.MLAgents.Tests.Actuators
{
internal class TestActuator : IActuator
{
public ActionBuffers LastActionBuffer;
public int[][] Masks;
public TestActuator(ActionSpec actuatorSpace, string name)
{
ActionSpec = actuatorSpace;
TotalNumberOfActions = actuatorSpace.NumContinuousActions +
actuatorSpace.NumDiscreteActions;
Name = name;
}
public void OnActionReceived(ActionBuffers actionBuffers)
{
LastActionBuffer = actionBuffers;
}
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
for (var i = 0; i < Masks.Length; i++)
{
actionMask.WriteMask(i, Masks[i]);
}
}
public int TotalNumberOfActions { get; }
public ActionSpec ActionSpec { get; }
public string Name { get; }
public void ResetData()
{
}
}
}

3
com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs.meta


fileFormatVersion: 2
guid: fa950d7b175749bfa287fd8761dd831f
timeCreated: 1596665978

98
com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs


using System.Collections.Generic;
using System.Linq;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using Assert = UnityEngine.Assertions.Assert;
namespace Unity.MLAgents.Tests.Actuators
{
[TestFixture]
public class VectorActuatorTests
{
class TestActionReceiver : IActionReceiver
{
public ActionBuffers LastActionBuffers;
public int Branch;
public IList<int> Mask;
public ActionSpec ActionSpec { get; }
public void OnActionReceived(ActionBuffers actionBuffers)
{
LastActionBuffers = actionBuffers;
}
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
actionMask.WriteMask(Branch, Mask);
}
}
[Test]
public void TestConstruct()
{
var ar = new TestActionReceiver();
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
Assert.IsTrue(va.ActionSpec.NumDiscreteActions == 3);
Assert.IsTrue(va.ActionSpec.SumOfDiscreteBranchSizes == 6);
Assert.IsTrue(va.ActionSpec.NumContinuousActions == 0);
var va1 = new VectorActuator(ar, new[] {4}, SpaceType.Continuous, "name");
Assert.IsTrue(va1.ActionSpec.NumContinuousActions == 4);
Assert.IsTrue(va1.ActionSpec.SumOfDiscreteBranchSizes == 0);
Assert.AreEqual(va1.Name, "name-Continuous");
}
[Test]
public void TestOnActionReceived()
{
var ar = new TestActionReceiver();
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
var discreteActions = new[] { 0, 1, 1 };
var ab = new ActionBuffers(ActionSegment<float>.Empty,
new ActionSegment<int>(discreteActions, 0, 3));
va.OnActionReceived(ab);
Assert.AreEqual(ar.LastActionBuffers, ab);
va.ResetData();
Assert.AreEqual(va.ActionBuffers.ContinuousActions, ActionSegment<float>.Empty);
Assert.AreEqual(va.ActionBuffers.DiscreteActions, ActionSegment<int>.Empty);
}
[Test]
public void TestResetData()
{
var ar = new TestActionReceiver();
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
var discreteActions = new[] { 0, 1, 1 };
var ab = new ActionBuffers(ActionSegment<float>.Empty,
new ActionSegment<int>(discreteActions, 0, 3));
va.OnActionReceived(ab);
}
[Test]
public void TestWriteDiscreteActionMask()
{
var ar = new TestActionReceiver();
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
var bdam = new ActuatorDiscreteActionMask(new[] {va}, 6, 3);
var groundTruthMask = new[] { false, true, false, false, true, true };
ar.Branch = 1;
ar.Mask = new[] { 0 };
va.WriteDiscreteActionMask(bdam);
ar.Branch = 2;
ar.Mask = new[] { 1, 2 };
va.WriteDiscreteActionMask(bdam);
Assert.IsTrue(groundTruthMask.SequenceEqual(bdam.GetMask()));
}
}
}

3
com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs.meta


fileFormatVersion: 2
guid: c2b191d2929f49adab0769705d49d86a
timeCreated: 1596580289
正在加载...
取消
保存