HH
4 年前
当前提交
d4bd7fe6
共有 48 个文件被更改,包括 1981 次插入 和 116 次删除
-
2Project/ProjectSettings/ProjectVersion.txt
-
2com.unity.ml-agents.extensions/package.json
-
21com.unity.ml-agents/CHANGELOG.md
-
2com.unity.ml-agents/Runtime/Academy.cs
-
2com.unity.ml-agents/package.json
-
2docs/Training-ML-Agents.md
-
22docs/Unity-Inference-Engine.md
-
8docs/Using-Tensorboard.md
-
2gym-unity/gym_unity/__init__.py
-
2ml-agents-envs/mlagents_envs/__init__.py
-
18ml-agents/mlagents/model_serialization.py
-
2ml-agents/mlagents/trainers/__init__.py
-
59ml-agents/mlagents/trainers/stats.py
-
42ml-agents/mlagents/trainers/tests/test_stats.py
-
8ml-agents/mlagents/trainers/trainer/rl_trainer.py
-
8com.unity.ml-agents/Runtime/Actuators.meta
-
8com.unity.ml-agents/Tests/Editor/Actuators.meta
-
160docs/images/TensorBoard-download.png
-
181com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
-
3com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta
-
75com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
-
3com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta
-
17com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs
-
3com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta
-
150com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs
-
3com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta
-
415com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
-
3com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta
-
101com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
-
3com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta
-
21com.unity.ml-agents/Runtime/Actuators/IActuator.cs
-
3com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta
-
38com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
-
3com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta
-
72com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
-
3com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta
-
55com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs
-
3com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs.meta
-
114com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs
-
3com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs.meta
-
310com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
-
3com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs.meta
-
38com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs
-
3com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs.meta
-
98com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs
-
3com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs.meta
|
|||
m_EditorVersion: 2018.4.18f1 |
|||
m_EditorVersion: 2018.4.24f1 |
|
|||
# Version of the library that will be used to upload to pypi |
|||
__version__ = "0.19.0.dev0" |
|||
__version__ = "0.20.0.dev0" |
|||
|
|||
# Git tag that will be checked to determine whether to trigger upload to pypi |
|||
__release_tag__ = None |
|
|||
# Version of the library that will be used to upload to pypi |
|||
__version__ = "0.19.0.dev0" |
|||
__version__ = "0.20.0.dev0" |
|||
|
|||
# Git tag that will be checked to determine whether to trigger upload to pypi |
|||
__release_tag__ = None |
|
|||
# Version of the library that will be used to upload to pypi |
|||
__version__ = "0.19.0.dev0" |
|||
__version__ = "0.20.0.dev0" |
|||
|
|||
# Git tag that will be checked to determine whether to trigger upload to pypi |
|||
__release_tag__ = None |
|
|||
fileFormatVersion: 2 |
|||
guid: 26733e59183b6479e8f0e892a8bf09a4 |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: c7e705f7d549e43c6be18ae809cd6f54 |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System; |
|||
using System.Collections; |
|||
using System.Collections.Generic; |
|||
using System.Diagnostics; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
|
|||
/// ActionSegment{T} is a data structure that allows access to a segment of an underlying array
|
|||
/// in order to avoid the copying and allocation of sub-arrays. The segment is defined by
|
|||
/// the offset into the original array, and an length.
|
|||
/// </summary>
|
|||
/// <typeparam name="T">The type of object stored in the underlying <see cref="Array"/></typeparam>
|
|||
internal readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>> |
|||
where T : struct |
|||
{ |
|||
/// <summary>
|
|||
/// The zero-based offset into the original array at which this segment starts.
|
|||
/// </summary>
|
|||
public readonly int Offset; |
|||
|
|||
/// <summary>
|
|||
/// The number of items this segment can access in the underlying array.
|
|||
/// </summary>
|
|||
public readonly int Length; |
|||
|
|||
/// <summary>
|
|||
/// An Empty segment which has an offset of 0, a Length of 0, and it's underlying array
|
|||
/// is also empty.
|
|||
/// </summary>
|
|||
public static ActionSegment<T> Empty = new ActionSegment<T>(System.Array.Empty<T>(), 0, 0); |
|||
|
|||
static void CheckParameters(T[] actionArray, int offset, int length) |
|||
{ |
|||
#if DEBUG
|
|||
if (offset + length > actionArray.Length) |
|||
{ |
|||
throw new ArgumentOutOfRangeException(nameof(offset), |
|||
$"Arguments offset: {offset} and length: {length} " + |
|||
$"are out of bounds of actionArray: {actionArray.Length}."); |
|||
} |
|||
#endif
|
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Construct an <see cref="ActionSegment{T}"/> with an underlying array
|
|||
/// and offset, and a length.
|
|||
/// </summary>
|
|||
/// <param name="actionArray">The underlying array which this segment has a view into</param>
|
|||
/// <param name="offset">The zero-based offset into the underlying array.</param>
|
|||
/// <param name="length">The length of the segment.</param>
|
|||
public ActionSegment(T[] actionArray, int offset, int length) |
|||
{ |
|||
CheckParameters(actionArray, offset, length); |
|||
Array = actionArray; |
|||
Offset = offset; |
|||
Length = length; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Get the underlying <see cref="Array"/> of this segment.
|
|||
/// </summary>
|
|||
public T[] Array { get; } |
|||
|
|||
/// <summary>
|
|||
/// Allows access to the underlying array using array syntax.
|
|||
/// </summary>
|
|||
/// <param name="index">The zero-based index of the segment.</param>
|
|||
/// <exception cref="IndexOutOfRangeException">Thrown when the index is less than 0 or
|
|||
/// greater than or equal to <see cref="Length"/></exception>
|
|||
public T this[int index] |
|||
{ |
|||
get |
|||
{ |
|||
if (index < 0 || index > Length) |
|||
{ |
|||
throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length}"); |
|||
} |
|||
return Array[Offset + index]; |
|||
} |
|||
} |
|||
|
|||
/// <inheritdoc cref="IEnumerable{T}.GetEnumerator"/>
|
|||
IEnumerator<T> IEnumerable<T>.GetEnumerator() |
|||
{ |
|||
return new Enumerator(this); |
|||
} |
|||
|
|||
/// <inheritdoc cref="IEnumerable{T}"/>
|
|||
public IEnumerator GetEnumerator() |
|||
{ |
|||
return new Enumerator(this); |
|||
} |
|||
|
|||
/// <inheritdoc cref="ValueType.Equals(object)"/>
|
|||
public override bool Equals(object obj) |
|||
{ |
|||
if (!(obj is ActionSegment<T>)) |
|||
{ |
|||
return false; |
|||
} |
|||
return Equals((ActionSegment<T>)obj); |
|||
} |
|||
|
|||
/// <inheritdoc cref="IEquatable{T}.Equals(T)"/>
|
|||
public bool Equals(ActionSegment<T> other) |
|||
{ |
|||
return Offset == other.Offset && Length == other.Length && Equals(Array, other.Array); |
|||
} |
|||
|
|||
/// <inheritdoc cref="ValueType.GetHashCode"/>
|
|||
public override int GetHashCode() |
|||
{ |
|||
unchecked |
|||
{ |
|||
var hashCode = Offset; |
|||
hashCode = (hashCode * 397) ^ Length; |
|||
hashCode = (hashCode * 397) ^ (Array != null ? Array.GetHashCode() : 0); |
|||
return hashCode; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// A private <see cref="IEnumerator{T}"/> for the <see cref="ActionSegment{T}"/> value type which follows its
|
|||
/// rules of being a view into an underlying <see cref="Array"/>.
|
|||
/// </summary>
|
|||
struct Enumerator : IEnumerator<T> |
|||
{ |
|||
readonly T[] m_Array; |
|||
readonly int m_Start; |
|||
readonly int m_End; // cache Offset + Count, since it's a little slow
|
|||
int m_Current; |
|||
|
|||
internal Enumerator(ActionSegment<T> arraySegment) |
|||
{ |
|||
Debug.Assert(arraySegment.Array != null); |
|||
Debug.Assert(arraySegment.Offset >= 0); |
|||
Debug.Assert(arraySegment.Length >= 0); |
|||
Debug.Assert(arraySegment.Offset + arraySegment.Length <= arraySegment.Array.Length); |
|||
|
|||
m_Array = arraySegment.Array; |
|||
m_Start = arraySegment.Offset; |
|||
m_End = arraySegment.Offset + arraySegment.Length; |
|||
m_Current = arraySegment.Offset - 1; |
|||
} |
|||
|
|||
public bool MoveNext() |
|||
{ |
|||
if (m_Current < m_End) |
|||
{ |
|||
m_Current++; |
|||
return m_Current < m_End; |
|||
} |
|||
return false; |
|||
} |
|||
|
|||
public T Current |
|||
{ |
|||
get |
|||
{ |
|||
if (m_Current < m_Start) |
|||
throw new InvalidOperationException("Enumerator not started."); |
|||
if (m_Current >= m_End) |
|||
throw new InvalidOperationException("Enumerator has reached the end already."); |
|||
return m_Array[m_Current]; |
|||
} |
|||
} |
|||
|
|||
object IEnumerator.Current => Current; |
|||
|
|||
void IEnumerator.Reset() |
|||
{ |
|||
m_Current = m_Start - 1; |
|||
} |
|||
|
|||
public void Dispose() |
|||
{ |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 4fa1432c1ba3460caaa84303a9011ef2 |
|||
timeCreated: 1595869823 |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using System.Linq; |
|||
using Unity.MLAgents.Policies; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
|
|||
/// Defines the structure of an Action Space to be used by the Actuator system.
|
|||
/// </summary>
|
|||
internal readonly struct ActionSpec |
|||
{ |
|||
|
|||
/// <summary>
|
|||
/// An array of branch sizes for our action space.
|
|||
///
|
|||
/// For an IActuator that uses a Discrete <see cref="SpaceType"/>, the number of
|
|||
/// branches is the Length of the Array and each index contains the branch size.
|
|||
/// The cumulative sum of the total number of discrete actions can be retrieved
|
|||
/// by the <see cref="SumOfDiscreteBranchSizes"/> property.
|
|||
///
|
|||
/// For an IActuator with a Continuous it will be null.
|
|||
/// </summary>
|
|||
public readonly int[] BranchSizes; |
|||
|
|||
/// <summary>
|
|||
/// The number of actions for a Continuous <see cref="SpaceType"/>.
|
|||
/// </summary>
|
|||
public int NumContinuousActions { get; } |
|||
|
|||
/// <summary>
|
|||
/// The number of branches for a Discrete <see cref="SpaceType"/>.
|
|||
/// </summary>
|
|||
public int NumDiscreteActions { get; } |
|||
|
|||
/// <summary>
|
|||
/// Get the total number of Discrete Actions that can be taken by calculating the Sum
|
|||
/// of all of the Discrete Action branch sizes.
|
|||
/// </summary>
|
|||
public int SumOfDiscreteBranchSizes { get; } |
|||
|
|||
/// <summary>
|
|||
/// Creates a Continuous <see cref="ActionSpec"/> with the number of actions available.
|
|||
/// </summary>
|
|||
/// <param name="numActions">The number of actions available.</param>
|
|||
/// <returns>An Continuous ActionSpec initialized with the number of actions available.</returns>
|
|||
public static ActionSpec MakeContinuous(int numActions) |
|||
{ |
|||
var actuatorSpace = new ActionSpec(numActions, 0); |
|||
return actuatorSpace; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Creates a Discrete <see cref="ActionSpec"/> with the array of branch sizes that
|
|||
/// represents the action space.
|
|||
/// </summary>
|
|||
/// <param name="branchSizes">The array of branch sizes for the discrete action space. Each index
|
|||
/// contains the number of actions available for that branch.</param>
|
|||
/// <returns>An Discrete ActionSpec initialized with the array of branch sizes.</returns>
|
|||
public static ActionSpec MakeDiscrete(int[] branchSizes) |
|||
{ |
|||
var numActions = branchSizes.Length; |
|||
var actuatorSpace = new ActionSpec(0, numActions, branchSizes); |
|||
return actuatorSpace; |
|||
} |
|||
|
|||
ActionSpec(int numContinuousActions, int numDiscreteActions, int[] branchSizes = null) |
|||
{ |
|||
NumContinuousActions = numContinuousActions; |
|||
NumDiscreteActions = numDiscreteActions; |
|||
BranchSizes = branchSizes; |
|||
SumOfDiscreteBranchSizes = branchSizes?.Sum() ?? 0; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: ecdd6deefba1416ca149fe09d2a5afd8 |
|||
timeCreated: 1595892361 |
|
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
|
|||
/// Editor components for creating Actuators. Generally an IActuator component should
|
|||
/// have a corresponding ActuatorComponent.
|
|||
/// </summary>
|
|||
internal abstract class ActuatorComponent : MonoBehaviour |
|||
{ |
|||
/// <summary>
|
|||
/// Create the IActuator. This is called by the Agent when it is initialized.
|
|||
/// </summary>
|
|||
/// <returns>Created IActuator object.</returns>
|
|||
public abstract IActuator CreateActuator(); |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 77cefae5f6d841be9ff80b41293d271b |
|||
timeCreated: 1593017318 |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using System.Linq; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
|
|||
/// Implementation of IDiscreteActionMask that allows writing to the action mask from an <see cref="IActuator"/>.
|
|||
/// </summary>
|
|||
internal class ActuatorDiscreteActionMask : IDiscreteActionMask |
|||
{ |
|||
/// When using discrete control, is the starting indices of the actions
|
|||
/// when all the branches are concatenated with each other.
|
|||
int[] m_StartingActionIndices; |
|||
|
|||
int[] m_BranchSizes; |
|||
|
|||
bool[] m_CurrentMask; |
|||
|
|||
IList<IActuator> m_Actuators; |
|||
|
|||
readonly int m_SumOfDiscreteBranchSizes; |
|||
readonly int m_NumBranches; |
|||
|
|||
/// <summary>
|
|||
/// The offset into the branches array that is used when actuators are writing to the action mask.
|
|||
/// </summary>
|
|||
public int CurrentBranchOffset { get; set; } |
|||
|
|||
internal ActuatorDiscreteActionMask(IList<IActuator> actuators, int sumOfDiscreteBranchSizes, int numBranches) |
|||
{ |
|||
m_Actuators = actuators; |
|||
m_SumOfDiscreteBranchSizes = sumOfDiscreteBranchSizes; |
|||
m_NumBranches = numBranches; |
|||
} |
|||
|
|||
/// <inheritdoc cref="IDiscreteActionMask.WriteMask"/>
|
|||
public void WriteMask(int branch, IEnumerable<int> actionIndices) |
|||
{ |
|||
LazyInitialize(); |
|||
|
|||
// Perform the masking
|
|||
foreach (var actionIndex in actionIndices) |
|||
{ |
|||
#if DEBUG
|
|||
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch]) |
|||
{ |
|||
throw new UnityAgentsException( |
|||
"Invalid Action Masking: Action Mask is too large for specified branch."); |
|||
} |
|||
#endif
|
|||
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = true; |
|||
} |
|||
} |
|||
|
|||
void LazyInitialize() |
|||
{ |
|||
if (m_BranchSizes == null) |
|||
{ |
|||
m_BranchSizes = new int[m_NumBranches]; |
|||
var start = 0; |
|||
for (var i = 0; i < m_Actuators.Count; i++) |
|||
{ |
|||
var actuator = m_Actuators[i]; |
|||
var branchSizes = actuator.ActionSpec.BranchSizes; |
|||
Array.Copy(branchSizes, 0, m_BranchSizes, start, branchSizes.Length); |
|||
start += branchSizes.Length; |
|||
} |
|||
} |
|||
|
|||
// By default, the masks are null. If we want to specify a new mask, we initialize
|
|||
// the actionMasks with trues.
|
|||
if (m_CurrentMask == null) |
|||
{ |
|||
m_CurrentMask = new bool[m_SumOfDiscreteBranchSizes]; |
|||
} |
|||
|
|||
// If this is the first time the masked actions are used, we generate the starting
|
|||
// indices for each branch.
|
|||
if (m_StartingActionIndices == null) |
|||
{ |
|||
m_StartingActionIndices = Utilities.CumSum(m_BranchSizes); |
|||
} |
|||
} |
|||
|
|||
/// <inheritdoc cref="IDiscreteActionMask.GetMask"/>
|
|||
public bool[] GetMask() |
|||
{ |
|||
#if DEBUG
|
|||
if (m_CurrentMask != null) |
|||
{ |
|||
AssertMask(); |
|||
} |
|||
#endif
|
|||
return m_CurrentMask; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Makes sure that the current mask is usable.
|
|||
/// </summary>
|
|||
void AssertMask() |
|||
{ |
|||
#if DEBUG
|
|||
for (var branchIndex = 0; branchIndex < m_NumBranches; branchIndex++) |
|||
{ |
|||
if (AreAllActionsMasked(branchIndex)) |
|||
{ |
|||
throw new UnityAgentsException( |
|||
"Invalid Action Masking : All the actions of branch " + branchIndex + |
|||
" are masked."); |
|||
} |
|||
} |
|||
#endif
|
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Resets the current mask for an agent.
|
|||
/// </summary>
|
|||
public void ResetMask() |
|||
{ |
|||
if (m_CurrentMask != null) |
|||
{ |
|||
Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Checks if all the actions in the input branch are masked.
|
|||
/// </summary>
|
|||
/// <param name="branch"> The index of the branch to check.</param>
|
|||
/// <returns> True if all the actions of the branch are masked.</returns>
|
|||
bool AreAllActionsMasked(int branch) |
|||
{ |
|||
if (m_CurrentMask == null) |
|||
{ |
|||
return false; |
|||
} |
|||
var start = m_StartingActionIndices[branch]; |
|||
var end = m_StartingActionIndices[branch + 1]; |
|||
for (var i = start; i < end; i++) |
|||
{ |
|||
if (!m_CurrentMask[i]) |
|||
{ |
|||
return false; |
|||
} |
|||
} |
|||
return true; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: d2a19e2f43fd4637a38d42b2a5f989f3 |
|||
timeCreated: 1595459316 |
|
|||
using System; |
|||
using System.Collections; |
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
|
|||
/// A class that manages the delegation of events, action buffers, and action mask for a list of IActuators.
|
|||
/// </summary>
|
|||
internal class ActuatorManager : IList<IActuator> |
|||
{ |
|||
// IActuators managed by this object.
|
|||
IList<IActuator> m_Actuators; |
|||
|
|||
// An implementation of IDiscreteActionMask that allows for writing to it based on an offset.
|
|||
ActuatorDiscreteActionMask m_DiscreteActionMask; |
|||
|
|||
/// <summary>
|
|||
/// Flag used to check if our IActuators are ready for execution.
|
|||
/// </summary>
|
|||
/// <seealso cref="ReadyActuatorsForExecution(IList{IActuator}, int, int, int)"/>
|
|||
bool m_ReadyForExecution; |
|||
|
|||
/// <summary>
|
|||
/// The sum of all of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
|
|||
/// </summary>
|
|||
internal int SumOfDiscreteBranchSizes { get; private set; } |
|||
|
|||
/// <summary>
|
|||
/// The number of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
|
|||
/// </summary>
|
|||
internal int NumDiscreteActions { get; private set; } |
|||
|
|||
/// <summary>
|
|||
/// The number of continuous actions for all of the <see cref="IActuator"/>s in this manager.
|
|||
/// </summary>
|
|||
internal int NumContinuousActions { get; private set; } |
|||
|
|||
/// <summary>
|
|||
/// Returns the total actions which is calculated by <see cref="NumContinuousActions"/> + <see cref="NumDiscreteActions"/>.
|
|||
/// </summary>
|
|||
public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions; |
|||
|
|||
/// <summary>
|
|||
/// Gets the <see cref="IDiscreteActionMask"/> managed by this object.
|
|||
/// </summary>
|
|||
public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask; |
|||
|
|||
/// <summary>
|
|||
/// Returns the previously stored actions for the actuators in this list.
|
|||
/// </summary>
|
|||
public float[] StoredContinuousActions { get; private set; } |
|||
|
|||
/// <summary>
|
|||
/// Returns the previously stored actions for the actuators in this list.
|
|||
/// </summary>
|
|||
public int[] StoredDiscreteActions { get; private set; } |
|||
|
|||
/// <summary>
|
|||
/// Create an ActuatorList with a preset capacity.
|
|||
/// </summary>
|
|||
/// <param name="capacity">The capacity of the list to create.</param>
|
|||
public ActuatorManager(int capacity = 0) |
|||
{ |
|||
m_Actuators = new List<IActuator>(capacity); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// <see cref="ReadyActuatorsForExecution(IList{IActuator}, int, int, int)"/>
|
|||
/// </summary>
|
|||
void ReadyActuatorsForExecution() |
|||
{ |
|||
ReadyActuatorsForExecution(m_Actuators, NumContinuousActions, SumOfDiscreteBranchSizes, |
|||
NumDiscreteActions); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// This method validates that all <see cref="IActuator"/>s have unique names and equivalent action space types
|
|||
/// if the `DEBUG` preprocessor macro is defined, and allocates the appropriate buffers to manage the actions for
|
|||
/// all of the <see cref="IActuator"/>s that may live on a particular object.
|
|||
/// </summary>
|
|||
/// <param name="actuators">The list of actuators to validate and allocate buffers for.</param>
|
|||
/// <param name="numContinuousActions">The total number of continuous actions for all of the actuators.</param>
|
|||
/// <param name="sumOfDiscreteBranches">The total sum of the discrete branches for all of the actuators in order
|
|||
/// to be able to allocate an <see cref="IDiscreteActionMask"/>.</param>
|
|||
/// <param name="numDiscreteBranches">The number of discrete branches for all of the actuators.</param>
|
|||
internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numContinuousActions, int sumOfDiscreteBranches, int numDiscreteBranches) |
|||
{ |
|||
if (m_ReadyForExecution) |
|||
{ |
|||
return; |
|||
} |
|||
#if DEBUG
|
|||
// Make sure the names are actually unique
|
|||
// Make sure all Actuators have the same SpaceType
|
|||
ValidateActuators(); |
|||
#endif
|
|||
|
|||
// Sort the Actuators by name to ensure determinism
|
|||
SortActuators(); |
|||
StoredContinuousActions = numContinuousActions == 0 ? Array.Empty<float>() : new float[numContinuousActions]; |
|||
StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty<int>() : new int[numDiscreteBranches]; |
|||
m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches); |
|||
m_ReadyForExecution = true; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Updates the local action buffer with the action buffer passed in. If the buffer
|
|||
/// passed in is null, the local action buffer will be cleared.
|
|||
/// </summary>
|
|||
/// <param name="continuousActionBuffer">The action buffer which contains all of the
|
|||
/// continuous actions for the IActuators in this list.</param>
|
|||
/// <param name="discreteActionBuffer">The action buffer which contains all of the
|
|||
/// discrete actions for the IActuators in this list.</param>
|
|||
public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer) |
|||
{ |
|||
ReadyActuatorsForExecution(); |
|||
UpdateActionArray(continuousActionBuffer, StoredContinuousActions); |
|||
UpdateActionArray(discreteActionBuffer, StoredDiscreteActions); |
|||
} |
|||
|
|||
static void UpdateActionArray<T>(T[] sourceActionBuffer, T[] destination) |
|||
{ |
|||
if (sourceActionBuffer == null || sourceActionBuffer.Length == 0) |
|||
{ |
|||
Array.Clear(destination, 0, destination.Length); |
|||
} |
|||
else |
|||
{ |
|||
Debug.Assert(sourceActionBuffer.Length == destination.Length, |
|||
$"sourceActionBuffer:{sourceActionBuffer.Length} is a different" + |
|||
$" size than destination: {destination.Length}."); |
|||
|
|||
Array.Copy(sourceActionBuffer, destination, destination.Length); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// This method will trigger the writing to the <see cref="IDiscreteActionMask"/> by all of the actuators
|
|||
/// managed by this object.
|
|||
/// </summary>
|
|||
public void WriteActionMask() |
|||
{ |
|||
ReadyActuatorsForExecution(); |
|||
m_DiscreteActionMask.ResetMask(); |
|||
var offset = 0; |
|||
for (var i = 0; i < m_Actuators.Count; i++) |
|||
{ |
|||
var actuator = m_Actuators[i]; |
|||
m_DiscreteActionMask.CurrentBranchOffset = offset; |
|||
actuator.WriteDiscreteActionMask(m_DiscreteActionMask); |
|||
offset += actuator.ActionSpec.NumDiscreteActions; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Iterates through all of the IActuators in this list and calls their
|
|||
/// <see cref="IActionReceiver.OnActionReceived"/> method on them with the appropriate
|
|||
/// <see cref="ActionSegment{T}"/>s depending on their <see cref="IActionReceiver.ActionSpec"/>.
|
|||
/// </summary>
|
|||
public void ExecuteActions() |
|||
{ |
|||
ReadyActuatorsForExecution(); |
|||
var continuousStart = 0; |
|||
var discreteStart = 0; |
|||
for (var i = 0; i < m_Actuators.Count; i++) |
|||
{ |
|||
var actuator = m_Actuators[i]; |
|||
var numContinuousActions = actuator.ActionSpec.NumContinuousActions; |
|||
var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions; |
|||
|
|||
var continuousActions = ActionSegment<float>.Empty; |
|||
if (numContinuousActions > 0) |
|||
{ |
|||
continuousActions = new ActionSegment<float>(StoredContinuousActions, |
|||
continuousStart, |
|||
numContinuousActions); |
|||
} |
|||
|
|||
var discreteActions = ActionSegment<int>.Empty; |
|||
if (numDiscreteActions > 0) |
|||
{ |
|||
discreteActions = new ActionSegment<int>(StoredDiscreteActions, |
|||
discreteStart, |
|||
numDiscreteActions); |
|||
} |
|||
|
|||
actuator.OnActionReceived(new ActionBuffers(continuousActions, discreteActions)); |
|||
continuousStart += numContinuousActions; |
|||
discreteStart += numDiscreteActions; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Resets the <see cref="StoredContinuousActions"/> and <see cref="StoredDiscreteActions"/> buffers to be all
|
|||
/// zeros and calls <see cref="IActuator.ResetData"/> on each <see cref="IActuator"/> managed by this object.
|
|||
/// </summary>
|
|||
public void ResetData() |
|||
{ |
|||
if (!m_ReadyForExecution) |
|||
{ |
|||
return; |
|||
} |
|||
Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length); |
|||
Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length); |
|||
for (var i = 0; i < m_Actuators.Count; i++) |
|||
{ |
|||
m_Actuators[i].ResetData(); |
|||
} |
|||
} |
|||
|
|||
|
|||
/// <summary>
|
|||
/// Sorts the <see cref="IActuator"/>s according to their <see cref="IActuator.GetName"/> value.
|
|||
/// </summary>
|
|||
void SortActuators() |
|||
{ |
|||
((List<IActuator>)m_Actuators).Sort((x, |
|||
y) => x.Name |
|||
.CompareTo(y.Name)); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Validates that the IActuators managed by this object have unique names and equivalent action space types.
|
|||
/// Each Actuator needs to have a unique name in order for this object to ensure that the storage of action
|
|||
/// buffers, and execution of Actuators remains deterministic across different sessions of running.
|
|||
/// </summary>
|
|||
void ValidateActuators() |
|||
{ |
|||
for (var i = 0; i < m_Actuators.Count - 1; i++) |
|||
{ |
|||
Debug.Assert( |
|||
!m_Actuators[i].Name.Equals(m_Actuators[i + 1].Name), |
|||
"Actuator names must be unique."); |
|||
var first = m_Actuators[i].ActionSpec; |
|||
var second = m_Actuators[i + 1].ActionSpec; |
|||
Debug.Assert(first.NumContinuousActions > 0 == second.NumContinuousActions > 0, |
|||
"Actuators on the same Agent must have the same action SpaceType."); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Helper method to update bookkeeping items around buffer management for actuators added to this object.
|
|||
/// </summary>
|
|||
/// <param name="actuatorItem">The IActuator to keep bookkeeping for.</param>
|
|||
void AddToBufferSizes(IActuator actuatorItem) |
|||
{ |
|||
if (actuatorItem == null) |
|||
{ |
|||
return; |
|||
} |
|||
|
|||
NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions; |
|||
NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions; |
|||
SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Helper method to update bookkeeping items around buffer management for actuators removed from this object.
|
|||
/// </summary>
|
|||
/// <param name="actuatorItem">The IActuator to keep bookkeeping for.</param>
|
|||
void SubtractFromBufferSize(IActuator actuatorItem) |
|||
{ |
|||
if (actuatorItem == null) |
|||
{ |
|||
return; |
|||
} |
|||
|
|||
NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions; |
|||
NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions; |
|||
SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Sets all of the bookkeeping items back to 0.
|
|||
/// </summary>
|
|||
void ClearBufferSizes() |
|||
{ |
|||
NumContinuousActions = NumDiscreteActions = SumOfDiscreteBranchSizes = 0; |
|||
} |
|||
|
|||
/********************************************************************************* |
|||
* IList implementation that delegates to m_Actuators List. * |
|||
*********************************************************************************/ |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="IEnumerable{T}.GetEnumerator"/>
|
|||
/// </summary>
|
|||
public IEnumerator<IActuator> GetEnumerator() |
|||
{ |
|||
return m_Actuators.GetEnumerator(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="IList{T}.GetEnumerator"/>
|
|||
/// </summary>
|
|||
IEnumerator IEnumerable.GetEnumerator() |
|||
{ |
|||
return ((IEnumerable)m_Actuators).GetEnumerator(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="ICollection{T}.Add"/>
|
|||
/// </summary>
|
|||
/// <param name="item"></param>
|
|||
public void Add(IActuator item) |
|||
{ |
|||
Debug.Assert(m_ReadyForExecution == false, |
|||
"Cannot add to the ActuatorManager after its buffers have been initialized"); |
|||
m_Actuators.Add(item); |
|||
AddToBufferSizes(item); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="ICollection{T}.Clear"/>
|
|||
/// </summary>
|
|||
public void Clear() |
|||
{ |
|||
Debug.Assert(m_ReadyForExecution == false, |
|||
"Cannot clear the ActuatorManager after its buffers have been initialized"); |
|||
m_Actuators.Clear(); |
|||
ClearBufferSizes(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="ICollection{T}.Contains"/>
|
|||
/// </summary>
|
|||
public bool Contains(IActuator item) |
|||
{ |
|||
return m_Actuators.Contains(item); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="ICollection{T}.CopyTo"/>
|
|||
/// </summary>
|
|||
public void CopyTo(IActuator[] array, int arrayIndex) |
|||
{ |
|||
m_Actuators.CopyTo(array, arrayIndex); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="ICollection{T}.Remove"/>
|
|||
/// </summary>
|
|||
public bool Remove(IActuator item) |
|||
{ |
|||
Debug.Assert(m_ReadyForExecution == false, |
|||
"Cannot remove from the ActuatorManager after its buffers have been initialized"); |
|||
if (m_Actuators.Remove(item)) |
|||
{ |
|||
SubtractFromBufferSize(item); |
|||
return true; |
|||
} |
|||
return false; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="ICollection{T}.Count"/>
|
|||
/// </summary>
|
|||
public int Count => m_Actuators.Count; |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="ICollection{T}.IsReadOnly"/>
|
|||
/// </summary>
|
|||
public bool IsReadOnly => m_Actuators.IsReadOnly; |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="IList{T}.IndexOf"/>
|
|||
/// </summary>
|
|||
public int IndexOf(IActuator item) |
|||
{ |
|||
return m_Actuators.IndexOf(item); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="IList{T}.Insert"/>
|
|||
/// </summary>
|
|||
public void Insert(int index, IActuator item) |
|||
{ |
|||
Debug.Assert(m_ReadyForExecution == false, |
|||
"Cannot insert into the ActuatorManager after its buffers have been initialized"); |
|||
m_Actuators.Insert(index, item); |
|||
AddToBufferSizes(item); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="IList{T}.RemoveAt"/>
|
|||
/// </summary>
|
|||
public void RemoveAt(int index) |
|||
{ |
|||
Debug.Assert(m_ReadyForExecution == false, |
|||
"Cannot remove from the ActuatorManager after its buffers have been initialized"); |
|||
var actuator = m_Actuators[index]; |
|||
SubtractFromBufferSize(actuator); |
|||
m_Actuators.RemoveAt(index); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="IList{T}.this"/>
|
|||
/// </summary>
|
|||
public IActuator this[int index] |
|||
{ |
|||
get => m_Actuators[index]; |
|||
set |
|||
{ |
|||
Debug.Assert(m_ReadyForExecution == false, |
|||
"Cannot modify the ActuatorManager after its buffers have been initialized"); |
|||
var old = m_Actuators[index]; |
|||
SubtractFromBufferSize(old); |
|||
m_Actuators[index] = value; |
|||
AddToBufferSizes(value); |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 7bb5b1e3779d4342a8e70f6e3c1d67cc |
|||
timeCreated: 1593031463 |
|
|||
using System; |
|||
using System.Linq; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
|
|||
/// A structure that wraps the <see cref="ActionSegment{T}"/>s for a particular <see cref="IActionReceiver"/> and is
|
|||
/// used when <see cref="IActionReceiver.OnActionReceived"/> is called.
|
|||
/// </summary>
|
|||
internal readonly struct ActionBuffers |
|||
{ |
|||
/// <summary>
|
|||
/// An empty action buffer.
|
|||
/// </summary>
|
|||
public static ActionBuffers Empty = new ActionBuffers(ActionSegment<float>.Empty, ActionSegment<int>.Empty); |
|||
|
|||
/// <summary>
|
|||
/// Holds the Continuous <see cref="ActionSegment{T}"/> to be used by an <see cref="IActionReceiver"/>.
|
|||
/// </summary>
|
|||
public ActionSegment<float> ContinuousActions { get; } |
|||
|
|||
/// <summary>
|
|||
/// Holds the Discrete <see cref="ActionSegment{T}"/> to be used by an <see cref="IActionReceiver"/>.
|
|||
/// </summary>
|
|||
public ActionSegment<int> DiscreteActions { get; } |
|||
|
|||
/// <summary>
|
|||
/// Construct an <see cref="ActionBuffers"/> instance with the continuous and discrete actions that will
|
|||
/// be used.
|
|||
/// </summary>
|
|||
/// <param name="continuousActions">The continuous actions to send to an <see cref="IActionReceiver"/>.</param>
|
|||
/// <param name="discreteActions">The discrete actions to send to an <see cref="IActionReceiver"/>.</param>
|
|||
public ActionBuffers(ActionSegment<float> continuousActions, ActionSegment<int> discreteActions) |
|||
{ |
|||
ContinuousActions = continuousActions; |
|||
DiscreteActions = discreteActions; |
|||
} |
|||
|
|||
/// <inheritdoc cref="ValueType.Equals(object)"/>
|
|||
public override bool Equals(object obj) |
|||
{ |
|||
if (!(obj is ActionBuffers)) |
|||
{ |
|||
return false; |
|||
} |
|||
|
|||
var ab = (ActionBuffers)obj; |
|||
return ab.ContinuousActions.SequenceEqual(ContinuousActions) && |
|||
ab.DiscreteActions.SequenceEqual(DiscreteActions); |
|||
} |
|||
|
|||
/// <inheritdoc cref="ValueType.GetHashCode"/>
|
|||
public override int GetHashCode() |
|||
{ |
|||
unchecked |
|||
{ |
|||
return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode(); |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// An interface that describes an object that can receive actions from a Reinforcement Learning network.
|
|||
/// </summary>
|
|||
internal interface IActionReceiver |
|||
{ |
|||
|
|||
/// <summary>
|
|||
/// The specification of the Action space for this IActionReceiver.
|
|||
/// </summary>
|
|||
/// <seealso cref="ActionSpec"/>
|
|||
ActionSpec ActionSpec { get; } |
|||
|
|||
/// <summary>
|
|||
/// Method called in order too allow object to execute actions based on the
|
|||
/// <see cref="ActionBuffers"/> contents. The structure of the contents in the <see cref="ActionBuffers"/>
|
|||
/// are defined by the <see cref="ActionSpec"/>.
|
|||
/// </summary>
|
|||
/// <param name="actionBuffers">The data structure containing the action buffers for this object.</param>
|
|||
void OnActionReceived(ActionBuffers actionBuffers); |
|||
|
|||
/// <summary>
|
|||
/// Implement `WriteDiscreteActionMask()` to modify the masks for discrete
|
|||
/// actions. When using discrete actions, the agent will not perform the masked
|
|||
/// action.
|
|||
/// </summary>
|
|||
/// <param name="actionMask">
|
|||
/// The action mask for the agent.
|
|||
/// </param>
|
|||
/// <remarks>
|
|||
/// When using Discrete Control, you can prevent the Agent from using a certain
|
|||
/// action by masking it with <see cref="IDiscreteActionMask.WriteMask"/>.
|
|||
///
|
|||
/// See [Agents - Actions] for more information on masking actions.
|
|||
///
|
|||
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#actions
|
|||
/// </remarks>
|
|||
/// <seealso cref="IActionReceiver.OnActionReceived"/>
|
|||
void WriteDiscreteActionMask(IDiscreteActionMask actionMask); |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: b25a5b3027c9476ea1a310241be0f10f |
|||
timeCreated: 1594756775 |
|
|||
using System; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
|
|||
/// Abstraction that facilitates the execution of actions.
|
|||
/// </summary>
|
|||
internal interface IActuator : IActionReceiver |
|||
{ |
|||
int TotalNumberOfActions { get; } |
|||
|
|||
/// <summary>
|
|||
/// Gets the name of this IActuator which will be used to sort it.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
string Name { get; } |
|||
|
|||
void ResetData(); |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 780d7f0a675f44bfa784b370025b51c3 |
|||
timeCreated: 1592848317 |
|
|||
using System.Collections.Generic; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
|
|||
/// Interface for writing a mask to disable discrete actions for agents for the next decision.
|
|||
/// </summary>
|
|||
internal interface IDiscreteActionMask |
|||
{ |
|||
/// <summary>
|
|||
/// Modifies an action mask for discrete control agents.
|
|||
/// </summary>
|
|||
/// <remarks>
|
|||
/// When used, the agent will not be able to perform the actions passed as argument
|
|||
/// at the next decision for the specified action branch. The actionIndices correspond
|
|||
/// to the action options the agent will be unable to perform.
|
|||
///
|
|||
/// See [Agents - Actions] for more information on masking actions.
|
|||
///
|
|||
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_2_docs/docs/Learning-Environment-Design-Agents.md#actions
|
|||
/// </remarks>
|
|||
/// <param name="branch">The branch for which the actions will be masked.</param>
|
|||
/// <param name="actionIndices">The indices of the masked actions.</param>
|
|||
void WriteMask(int branch, IEnumerable<int> actionIndices); |
|||
|
|||
/// <summary>
|
|||
/// Get the current mask for an agent.
|
|||
/// </summary>
|
|||
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
|
|||
/// actions.</returns>
|
|||
bool[] GetMask(); |
|||
|
|||
/// <summary>
|
|||
/// Resets the current mask for an agent.
|
|||
/// </summary>
|
|||
void ResetMask(); |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 1bc4e4b71bf4470789488fab2ee65388 |
|||
timeCreated: 1595369065 |
|
|||
using System; |
|||
|
|||
using Unity.MLAgents.Policies; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
internal class VectorActuator : IActuator |
|||
{ |
|||
IActionReceiver m_ActionReceiver; |
|||
|
|||
ActionBuffers m_ActionBuffers; |
|||
internal ActionBuffers ActionBuffers |
|||
{ |
|||
get => m_ActionBuffers; |
|||
private set => m_ActionBuffers = value; |
|||
} |
|||
|
|||
public VectorActuator(IActionReceiver actionReceiver, |
|||
int[] vectorActionSize, |
|||
SpaceType spaceType, |
|||
string name = "VectorActuator") |
|||
{ |
|||
m_ActionReceiver = actionReceiver; |
|||
string suffix; |
|||
switch (spaceType) |
|||
{ |
|||
case SpaceType.Continuous: |
|||
ActionSpec = ActionSpec.MakeContinuous(vectorActionSize[0]); |
|||
suffix = "-Continuous"; |
|||
break; |
|||
case SpaceType.Discrete: |
|||
ActionSpec = ActionSpec.MakeDiscrete(vectorActionSize); |
|||
suffix = "-Discrete"; |
|||
break; |
|||
default: |
|||
throw new ArgumentOutOfRangeException(nameof(spaceType), |
|||
spaceType, |
|||
"Unknown enum value."); |
|||
} |
|||
Name = name + suffix; |
|||
} |
|||
|
|||
public void ResetData() |
|||
{ |
|||
m_ActionBuffers = ActionBuffers.Empty; |
|||
} |
|||
|
|||
public void OnActionReceived(ActionBuffers actionBuffers) |
|||
{ |
|||
ActionBuffers = actionBuffers; |
|||
m_ActionReceiver.OnActionReceived(ActionBuffers); |
|||
} |
|||
|
|||
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) |
|||
{ |
|||
m_ActionReceiver.WriteDiscreteActionMask(actionMask); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns the number of discrete branches + the number of continuous actions.
|
|||
/// </summary>
|
|||
public int TotalNumberOfActions => ActionSpec.NumContinuousActions + |
|||
ActionSpec.NumDiscreteActions; |
|||
|
|||
/// <summary>
|
|||
/// <inheritdoc cref="IActionReceiver.ActionSpec"/>
|
|||
/// </summary>
|
|||
public ActionSpec ActionSpec { get; } |
|||
|
|||
public string Name { get; } |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: ff7a3292c0b24b23b3f1c0eeb690ec4c |
|||
timeCreated: 1593023833 |
|
|||
using System; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Actuators; |
|||
|
|||
namespace Unity.MLAgents.Tests.Actuators |
|||
{ |
|||
[TestFixture] |
|||
public class ActionSegmentTests |
|||
{ |
|||
[Test] |
|||
public void TestConstruction() |
|||
{ |
|||
var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f }; |
|||
Assert.Throws<ArgumentOutOfRangeException>( |
|||
() => new ActionSegment<float>(floatArray, 100, 1)); |
|||
|
|||
var segment = new ActionSegment<float>(Array.Empty<float>(), 0, 0); |
|||
Assert.AreEqual(segment, ActionSegment<float>.Empty); |
|||
} |
|||
[Test] |
|||
public void TestIndexing() |
|||
{ |
|||
var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f }; |
|||
for (var i = 0; i < floatArray.Length; i++) |
|||
{ |
|||
var start = 0 + i; |
|||
var length = floatArray.Length - i; |
|||
var actionSegment = new ActionSegment<float>(floatArray, start, length); |
|||
for (var j = 0; j < actionSegment.Length; j++) |
|||
{ |
|||
Assert.AreEqual(actionSegment[j], floatArray[start + j]); |
|||
} |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestEnumerator() |
|||
{ |
|||
var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f }; |
|||
for (var i = 0; i < floatArray.Length; i++) |
|||
{ |
|||
var start = 0 + i; |
|||
var length = floatArray.Length - i; |
|||
var actionSegment = new ActionSegment<float>(floatArray, start, length); |
|||
var j = 0; |
|||
foreach (var item in actionSegment) |
|||
{ |
|||
Assert.AreEqual(item, floatArray[start + j++]); |
|||
} |
|||
} |
|||
} |
|||
|
|||
} |
|||
|
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 18cb6d052fba43a2b7437d87c0d9abad |
|||
timeCreated: 1596486604 |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Actuators; |
|||
|
|||
namespace Unity.MLAgents.Tests.Actuators |
|||
{ |
|||
[TestFixture] |
|||
public class ActuatorDiscreteActionMaskTests |
|||
{ |
|||
[Test] |
|||
public void Construction() |
|||
{ |
|||
var masker = new ActuatorDiscreteActionMask(new List<IActuator>(), 0, 0); |
|||
Assert.IsNotNull(masker); |
|||
} |
|||
|
|||
[Test] |
|||
public void NullMask() |
|||
{ |
|||
var masker = new ActuatorDiscreteActionMask(new List<IActuator>(), 0, 0); |
|||
var mask = masker.GetMask(); |
|||
Assert.IsNull(mask); |
|||
} |
|||
|
|||
[Test] |
|||
public void FirstBranchMask() |
|||
{ |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1"); |
|||
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3); |
|||
var mask = masker.GetMask(); |
|||
Assert.IsNull(mask); |
|||
masker.WriteMask(0, new[] {1, 2, 3}); |
|||
mask = masker.GetMask(); |
|||
Assert.IsFalse(mask[0]); |
|||
Assert.IsTrue(mask[1]); |
|||
Assert.IsTrue(mask[2]); |
|||
Assert.IsTrue(mask[3]); |
|||
Assert.IsFalse(mask[4]); |
|||
Assert.AreEqual(mask.Length, 15); |
|||
} |
|||
|
|||
[Test] |
|||
public void SecondBranchMask() |
|||
{ |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1"); |
|||
var masker = new ActuatorDiscreteActionMask(new[] {actuator1}, 15, 3); |
|||
masker.WriteMask(1, new[] {1, 2, 3}); |
|||
var mask = masker.GetMask(); |
|||
Assert.IsFalse(mask[0]); |
|||
Assert.IsFalse(mask[4]); |
|||
Assert.IsTrue(mask[5]); |
|||
Assert.IsTrue(mask[6]); |
|||
Assert.IsTrue(mask[7]); |
|||
Assert.IsFalse(mask[8]); |
|||
Assert.IsFalse(mask[9]); |
|||
} |
|||
|
|||
[Test] |
|||
public void MaskReset() |
|||
{ |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1"); |
|||
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3); |
|||
masker.WriteMask(1, new[] {1, 2, 3}); |
|||
masker.ResetMask(); |
|||
var mask = masker.GetMask(); |
|||
for (var i = 0; i < 15; i++) |
|||
{ |
|||
Assert.IsFalse(mask[i]); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void ThrowsError() |
|||
{ |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1"); |
|||
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3); |
|||
Assert.Catch<UnityAgentsException>( |
|||
() => masker.WriteMask(0, new[] {5})); |
|||
Assert.Catch<UnityAgentsException>( |
|||
() => masker.WriteMask(1, new[] {5})); |
|||
masker.WriteMask(2, new[] {5}); |
|||
Assert.Catch<UnityAgentsException>( |
|||
() => masker.WriteMask(3, new[] {1})); |
|||
masker.GetMask(); |
|||
masker.ResetMask(); |
|||
masker.WriteMask(0, new[] {0, 1, 2, 3}); |
|||
Assert.Catch<UnityAgentsException>( |
|||
() => masker.GetMask()); |
|||
} |
|||
|
|||
[Test] |
|||
public void MultipleMaskEdit() |
|||
{ |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1"); |
|||
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3); |
|||
masker.WriteMask(0, new[] {0, 1}); |
|||
masker.WriteMask(0, new[] {3}); |
|||
masker.WriteMask(2, new[] {1}); |
|||
var mask = masker.GetMask(); |
|||
for (var i = 0; i < 15; i++) |
|||
{ |
|||
if ((i == 0) || (i == 1) || (i == 3) || (i == 10)) |
|||
{ |
|||
Assert.IsTrue(mask[i]); |
|||
} |
|||
else |
|||
{ |
|||
Assert.IsFalse(mask[i]); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: b9f5f87049d04d8bba39d193a3ab2f5a |
|||
timeCreated: 1596491682 |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using System.Linq; |
|||
using System.Text.RegularExpressions; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Actuators; |
|||
using Unity.MLAgents.Policies; |
|||
using UnityEngine; |
|||
using UnityEngine.TestTools; |
|||
using Assert = UnityEngine.Assertions.Assert; |
|||
|
|||
namespace Unity.MLAgents.Tests.Actuators |
|||
{ |
|||
[TestFixture] |
|||
public class ActuatorManagerTests |
|||
{ |
|||
[Test] |
|||
public void TestEnsureBufferSizeContinuous() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(10), "actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(2), "actuator2"); |
|||
manager.Add(actuator1); |
|||
manager.Add(actuator2); |
|||
var actuator1ActionSpaceDef = actuator1.ActionSpec; |
|||
var actuator2ActionSpaceDef = actuator2.ActionSpec; |
|||
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, |
|||
actuator1ActionSpaceDef.NumContinuousActions + actuator2ActionSpaceDef.NumContinuousActions, |
|||
actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes, |
|||
actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions); |
|||
|
|||
manager.UpdateActions(new[] |
|||
{ 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f }, Array.Empty<int>()); |
|||
|
|||
Assert.IsTrue(12 == manager.NumContinuousActions); |
|||
Assert.IsTrue(0 == manager.NumDiscreteActions); |
|||
Assert.IsTrue(0 == manager.SumOfDiscreteBranchSizes); |
|||
Assert.IsTrue(12 == manager.StoredContinuousActions.Length); |
|||
Assert.IsTrue(0 == manager.StoredDiscreteActions.Length); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestEnsureBufferDiscrete() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 1, 1}), "actuator2"); |
|||
manager.Add(actuator1); |
|||
manager.Add(actuator2); |
|||
var actuator1ActionSpaceDef = actuator1.ActionSpec; |
|||
var actuator2ActionSpaceDef = actuator2.ActionSpec; |
|||
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, |
|||
actuator1ActionSpaceDef.NumContinuousActions + actuator2ActionSpaceDef.NumContinuousActions, |
|||
actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes, |
|||
actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions); |
|||
|
|||
manager.UpdateActions(Array.Empty<float>(), |
|||
new[] { 0, 1, 2, 3, 4, 5, 6}); |
|||
|
|||
Assert.IsTrue(0 == manager.NumContinuousActions); |
|||
Assert.IsTrue(7 == manager.NumDiscreteActions); |
|||
Assert.IsTrue(13 == manager.SumOfDiscreteBranchSizes); |
|||
Assert.IsTrue(0 == manager.StoredContinuousActions.Length); |
|||
Assert.IsTrue(7 == manager.StoredDiscreteActions.Length); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestFailOnMixedActionSpace() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2"); |
|||
manager.Add(actuator1); |
|||
manager.Add(actuator2); |
|||
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, 3, 10, 4); |
|||
LogAssert.Expect(LogType.Assert, "Actuators on the same Agent must have the same action SpaceType."); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestFailOnSameActuatorName() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator1"); |
|||
manager.Add(actuator1); |
|||
manager.Add(actuator2); |
|||
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, 3, 10, 4); |
|||
LogAssert.Expect(LogType.Assert, "Actuator names must be unique."); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestExecuteActionsDiscrete() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 1, 1}), "actuator2"); |
|||
manager.Add(actuator1); |
|||
manager.Add(actuator2); |
|||
|
|||
var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5, 6}; |
|||
manager.UpdateActions(Array.Empty<float>(), |
|||
discreteActionBuffer); |
|||
|
|||
manager.ExecuteActions(); |
|||
var actuator1Actions = actuator1.LastActionBuffer.DiscreteActions; |
|||
var actuator2Actions = actuator2.LastActionBuffer.DiscreteActions; |
|||
TestSegmentEquality(actuator1Actions, discreteActionBuffer); TestSegmentEquality(actuator2Actions, discreteActionBuffer); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestExecuteActionsContinuous() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3), |
|||
"actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2"); |
|||
manager.Add(actuator1); |
|||
manager.Add(actuator2); |
|||
|
|||
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f}; |
|||
manager.UpdateActions(continuousActionBuffer, |
|||
Array.Empty<int>()); |
|||
|
|||
manager.ExecuteActions(); |
|||
var actuator1Actions = actuator1.LastActionBuffer.ContinuousActions; |
|||
var actuator2Actions = actuator2.LastActionBuffer.ContinuousActions; |
|||
TestSegmentEquality(actuator1Actions, continuousActionBuffer); |
|||
TestSegmentEquality(actuator2Actions, continuousActionBuffer); |
|||
} |
|||
|
|||
static void TestSegmentEquality<T>(ActionSegment<T> actionSegment, T[] actionBuffer) |
|||
where T : struct |
|||
{ |
|||
Assert.IsFalse(actionSegment.Length == 0); |
|||
for (var i = 0; i < actionSegment.Length; i++) |
|||
{ |
|||
var action = actionSegment[i]; |
|||
Assert.AreEqual(action, actionBuffer[actionSegment.Offset + i]); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestUpdateActionsContinuous() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3), |
|||
"actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2"); |
|||
manager.Add(actuator1); |
|||
manager.Add(actuator2); |
|||
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f}; |
|||
manager.UpdateActions(continuousActionBuffer, |
|||
Array.Empty<int>()); |
|||
|
|||
Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestUpdateActionsDiscrete() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), |
|||
"actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2"); |
|||
manager.Add(actuator1); |
|||
manager.Add(actuator2); |
|||
var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5}; |
|||
manager.UpdateActions(Array.Empty<float>(), |
|||
discreteActionBuffer); |
|||
|
|||
Debug.Log(manager.StoredDiscreteActions); |
|||
Debug.Log(discreteActionBuffer); |
|||
Assert.IsTrue(manager.StoredDiscreteActions.SequenceEqual(discreteActionBuffer)); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestRemove() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), |
|||
"actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2"); |
|||
|
|||
manager.Add(actuator1); |
|||
manager.Add(actuator2); |
|||
Assert.IsTrue(manager.NumDiscreteActions == 6); |
|||
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 12); |
|||
|
|||
manager.Remove(actuator2); |
|||
|
|||
Assert.IsTrue(manager.NumDiscreteActions == 3); |
|||
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6); |
|||
|
|||
manager.Remove(null); |
|||
|
|||
Assert.IsTrue(manager.NumDiscreteActions == 3); |
|||
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6); |
|||
|
|||
manager.RemoveAt(0); |
|||
Assert.IsTrue(manager.NumDiscreteActions == 0); |
|||
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 0); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestClear() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }), |
|||
"actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2"); |
|||
manager.Add(actuator1); |
|||
manager.Add(actuator2); |
|||
|
|||
Assert.IsTrue(manager.NumDiscreteActions == 6); |
|||
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 12); |
|||
|
|||
manager.Clear(); |
|||
|
|||
Assert.IsTrue(manager.NumDiscreteActions == 0); |
|||
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 0); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestIndexSet() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4}), |
|||
"actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2"); |
|||
manager.Add(actuator1); |
|||
Assert.IsTrue(manager.NumDiscreteActions == 4); |
|||
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 10); |
|||
manager[0] = actuator2; |
|||
Assert.IsTrue(manager.NumDiscreteActions == 3); |
|||
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestInsert() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4}), |
|||
"actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2"); |
|||
manager.Add(actuator1); |
|||
Assert.IsTrue(manager.NumDiscreteActions == 4); |
|||
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 10); |
|||
manager.Insert(0, actuator2); |
|||
Assert.IsTrue(manager.NumDiscreteActions == 7); |
|||
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 16); |
|||
Assert.IsTrue(manager.IndexOf(actuator2) == 0); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestResetData() |
|||
{ |
|||
var manager = new ActuatorManager(); |
|||
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3), |
|||
"actuator1"); |
|||
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2"); |
|||
manager.Add(actuator1); |
|||
manager.Add(actuator2); |
|||
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f}; |
|||
manager.UpdateActions(continuousActionBuffer, |
|||
Array.Empty<int>()); |
|||
|
|||
Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer)); |
|||
Assert.IsTrue(manager.NumContinuousActions == 6); |
|||
manager.ResetData(); |
|||
|
|||
Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(new[] { 0f, 0f, 0f, 0f, 0f, 0f})); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestWriteDiscreteActionMask() |
|||
{ |
|||
var manager = new ActuatorManager(2); |
|||
var va1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "name"); |
|||
var va2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {3, 2, 1}), "name1"); |
|||
manager.Add(va1); |
|||
manager.Add(va2); |
|||
|
|||
var groundTruthMask = new[] |
|||
{ |
|||
false, |
|||
true, false, |
|||
false, true, true, |
|||
true, false, true, |
|||
false, true, |
|||
false |
|||
}; |
|||
|
|||
va1.Masks = new[] |
|||
{ |
|||
Array.Empty<int>(), |
|||
new[] { 0 }, |
|||
new[] { 1, 2 } |
|||
}; |
|||
|
|||
va2.Masks = new[] |
|||
{ |
|||
new[] {0, 2}, |
|||
new[] {1}, |
|||
Array.Empty<int>() |
|||
}; |
|||
manager.WriteActionMask(); |
|||
Assert.IsTrue(groundTruthMask.SequenceEqual(manager.DiscreteActionMask.GetMask())); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: d48ba72f0ac64d7db0af22c9d82b11d8 |
|||
timeCreated: 1596494279 |
|
|||
using Unity.MLAgents.Actuators; |
|||
namespace Unity.MLAgents.Tests.Actuators |
|||
{ |
|||
internal class TestActuator : IActuator |
|||
{ |
|||
public ActionBuffers LastActionBuffer; |
|||
public int[][] Masks; |
|||
public TestActuator(ActionSpec actuatorSpace, string name) |
|||
{ |
|||
ActionSpec = actuatorSpace; |
|||
TotalNumberOfActions = actuatorSpace.NumContinuousActions + |
|||
actuatorSpace.NumDiscreteActions; |
|||
Name = name; |
|||
} |
|||
|
|||
public void OnActionReceived(ActionBuffers actionBuffers) |
|||
{ |
|||
LastActionBuffer = actionBuffers; |
|||
} |
|||
|
|||
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) |
|||
{ |
|||
for (var i = 0; i < Masks.Length; i++) |
|||
{ |
|||
actionMask.WriteMask(i, Masks[i]); |
|||
} |
|||
} |
|||
|
|||
public int TotalNumberOfActions { get; } |
|||
public ActionSpec ActionSpec { get; } |
|||
|
|||
public string Name { get; } |
|||
|
|||
public void ResetData() |
|||
{ |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: fa950d7b175749bfa287fd8761dd831f |
|||
timeCreated: 1596665978 |
|
|||
using System.Collections.Generic; |
|||
using System.Linq; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Actuators; |
|||
using Unity.MLAgents.Policies; |
|||
using Assert = UnityEngine.Assertions.Assert; |
|||
|
|||
namespace Unity.MLAgents.Tests.Actuators |
|||
{ |
|||
[TestFixture] |
|||
public class VectorActuatorTests |
|||
{ |
|||
class TestActionReceiver : IActionReceiver |
|||
{ |
|||
public ActionBuffers LastActionBuffers; |
|||
public int Branch; |
|||
public IList<int> Mask; |
|||
public ActionSpec ActionSpec { get; } |
|||
|
|||
public void OnActionReceived(ActionBuffers actionBuffers) |
|||
{ |
|||
LastActionBuffers = actionBuffers; |
|||
} |
|||
|
|||
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) |
|||
{ |
|||
actionMask.WriteMask(Branch, Mask); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestConstruct() |
|||
{ |
|||
var ar = new TestActionReceiver(); |
|||
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name"); |
|||
|
|||
Assert.IsTrue(va.ActionSpec.NumDiscreteActions == 3); |
|||
Assert.IsTrue(va.ActionSpec.SumOfDiscreteBranchSizes == 6); |
|||
Assert.IsTrue(va.ActionSpec.NumContinuousActions == 0); |
|||
|
|||
var va1 = new VectorActuator(ar, new[] {4}, SpaceType.Continuous, "name"); |
|||
|
|||
Assert.IsTrue(va1.ActionSpec.NumContinuousActions == 4); |
|||
Assert.IsTrue(va1.ActionSpec.SumOfDiscreteBranchSizes == 0); |
|||
Assert.AreEqual(va1.Name, "name-Continuous"); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestOnActionReceived() |
|||
{ |
|||
var ar = new TestActionReceiver(); |
|||
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name"); |
|||
|
|||
var discreteActions = new[] { 0, 1, 1 }; |
|||
var ab = new ActionBuffers(ActionSegment<float>.Empty, |
|||
new ActionSegment<int>(discreteActions, 0, 3)); |
|||
|
|||
va.OnActionReceived(ab); |
|||
|
|||
Assert.AreEqual(ar.LastActionBuffers, ab); |
|||
va.ResetData(); |
|||
Assert.AreEqual(va.ActionBuffers.ContinuousActions, ActionSegment<float>.Empty); |
|||
Assert.AreEqual(va.ActionBuffers.DiscreteActions, ActionSegment<int>.Empty); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestResetData() |
|||
{ |
|||
var ar = new TestActionReceiver(); |
|||
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name"); |
|||
|
|||
var discreteActions = new[] { 0, 1, 1 }; |
|||
var ab = new ActionBuffers(ActionSegment<float>.Empty, |
|||
new ActionSegment<int>(discreteActions, 0, 3)); |
|||
|
|||
va.OnActionReceived(ab); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestWriteDiscreteActionMask() |
|||
{ |
|||
var ar = new TestActionReceiver(); |
|||
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name"); |
|||
var bdam = new ActuatorDiscreteActionMask(new[] {va}, 6, 3); |
|||
|
|||
var groundTruthMask = new[] { false, true, false, false, true, true }; |
|||
|
|||
ar.Branch = 1; |
|||
ar.Mask = new[] { 0 }; |
|||
va.WriteDiscreteActionMask(bdam); |
|||
ar.Branch = 2; |
|||
ar.Mask = new[] { 1, 2 }; |
|||
va.WriteDiscreteActionMask(bdam); |
|||
|
|||
Assert.IsTrue(groundTruthMask.SequenceEqual(bdam.GetMask())); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: c2b191d2929f49adab0769705d49d86a |
|||
timeCreated: 1596580289 |
撰写
预览
正在加载...
取消
保存
Reference in new issue