
Integrate IActuators into ML-Agents core code. (#4315)

Co-authored-by: Chris Elion <chris.elion@unity3d.com>
/MLA-1734-demo-provider
GitHub · 4 years ago
Current commit
3a7572b4
24 files changed, with 426 additions and 418 deletions
  1. com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs (33 changes)
  2. com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs (2 changes)
  3. com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs (51 changes)
  4. com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs (72 changes)
  5. com.unity.ml-agents/Runtime/Actuators/IActuator.cs (2 changes)
  6. com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs (2 changes)
  7. com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs (2 changes)
  8. com.unity.ml-agents/Runtime/Agent.cs (201 changes)
  9. com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs (14 changes)
  10. com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs (2 changes)
  11. com.unity.ml-agents/Runtime/DecisionRequester.cs (6 changes)
  12. com.unity.ml-agents/Runtime/DiscreteActionMasker.cs (118 changes)
  13. com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs (17 changes)
  14. com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs (37 changes)
  15. com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs (20 changes)
  16. com.unity.ml-agents/Runtime/Policies/IPolicy.cs (3 changes)
  17. com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs (15 changes)
  18. com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs (46 changes)
  19. com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs (3 changes)
  20. com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs (4 changes)
  21. com.unity.ml-agents/Runtime/Agent.deprecated.cs (38 changes)
  22. com.unity.ml-agents/Runtime/Agent.deprecated.cs.meta (3 changes)
  23. com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs.meta (11 changes)
  24. com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs (142 changes)

com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs (33 changes)


/// the offset into the original array, and a length.
/// </summary>
/// <typeparam name="T">The type of object stored in the underlying <see cref="Array"/></typeparam>
- internal readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
+ public readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
where T : struct
{
/// <summary>

/// </summary>
public static ActionSegment<T> Empty = new ActionSegment<T>(System.Array.Empty<T>(), 0, 0);
- static void CheckParameters(T[] actionArray, int offset, int length)
+ static void CheckParameters(IReadOnlyCollection<T> actionArray, int offset, int length)
- if (offset + length > actionArray.Length)
+ if (offset + length > actionArray.Count)
- $"are out of bounds of actionArray: {actionArray.Length}.");
+ $"are out of bounds of actionArray: {actionArray.Count}.");
/// Construct an <see cref="ActionSegment{T}"/> with just an actionArray. The <see cref="Offset"/> will
/// be set to 0 and the <see cref="Length"/> will be set to `actionArray.Length`.
/// </summary>
/// <param name="actionArray">The action array to use for the this segment.</param>
public ActionSegment(T[] actionArray) : this(actionArray, 0, actionArray.Length) { }
/// <summary>
/// Construct an <see cref="ActionSegment{T}"/> with an underlying array
/// and offset, and a length.
/// </summary>

public ActionSegment(T[] actionArray, int offset, int length)
{
#if DEBUG
#endif
Array = actionArray;
Offset = offset;
Length = length;

}
return Array[Offset + index];
}
set
{
if (index < 0 || index >= Length)
{
throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length}");
}
Array[Offset + index] = value;
}
}
/// <summary>
/// Sets the segment of the backing array to all zeros.
/// </summary>
public void Clear()
{
System.Array.Clear(Array, Offset, Length);
}
/// <inheritdoc cref="IEnumerable{T}.GetEnumerator"/>
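Note: a minimal usage sketch of the now-public ActionSegment<T> (not part of the diff; variable names are illustrative):

// Two segments viewing disjoint slices of one backing array, using the
// ActionSegment(T[], int, int) constructor, indexer, and Clear() shown above.
var backing = new float[] { 0f, 1f, 2f, 3f, 4f };
var first = new ActionSegment<float>(backing, 0, 2);   // views backing[0..1]
var second = new ActionSegment<float>(backing, 2, 3);  // views backing[2..4]
second[0] = 20f;   // writes through to backing[2]
first.Clear();     // zeros only backing[0] and backing[1]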

com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs (2 changes)


/// <summary>
/// Defines the structure of an Action Space to be used by the Actuator system.
/// </summary>
- internal readonly struct ActionSpec
+ public readonly struct ActionSpec
{
/// <summary>

com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs (51 changes)


/// <summary>
/// Returns the previously stored actions for the actuators in this list.
/// </summary>
- public float[] StoredContinuousActions { get; private set; }
+ // public float[] StoredContinuousActions { get; private set; }
- public int[] StoredDiscreteActions { get; private set; }
+ // public int[] StoredDiscreteActions { get; private set; }
+ public ActionBuffers StoredActions { get; private set; }
/// <summary>
/// Create an ActuatorList with a preset capacity.

// Sort the Actuators by name to ensure determinism
SortActuators();
- StoredContinuousActions = numContinuousActions == 0 ? Array.Empty<float>() : new float[numContinuousActions];
- StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty<int>() : new int[numDiscreteBranches];
+ var continuousActions = numContinuousActions == 0 ? ActionSegment<float>.Empty :
+     new ActionSegment<float>(new float[numContinuousActions]);
+ var discreteActions = numDiscreteBranches == 0 ? ActionSegment<int>.Empty : new ActionSegment<int>(new int[numDiscreteBranches]);
+ StoredActions = new ActionBuffers(continuousActions, discreteActions);
m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches);
m_ReadyForExecution = true;
}

/// continuous actions for the IActuators in this list.</param>
/// <param name="discreteActionBuffer">The action buffer which contains all of the
/// discrete actions for the IActuators in this list.</param>
- public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer)
+ public void UpdateActions(ActionBuffers actions)
- UpdateActionArray(continuousActionBuffer, StoredContinuousActions);
- UpdateActionArray(discreteActionBuffer, StoredDiscreteActions);
+ UpdateActionArray(actions.ContinuousActions, StoredActions.ContinuousActions);
+ UpdateActionArray(actions.DiscreteActions, StoredActions.DiscreteActions);
- static void UpdateActionArray<T>(T[] sourceActionBuffer, T[] destination)
+ static void UpdateActionArray<T>(ActionSegment<T> sourceActionBuffer, ActionSegment<T> destination)
where T : struct
- if (sourceActionBuffer == null || sourceActionBuffer.Length == 0)
+ if (sourceActionBuffer.Length <= 0)
- Array.Clear(destination, 0, destination.Length);
+ destination.Clear();
}
else
{

- Array.Copy(sourceActionBuffer, destination, destination.Length);
+ Array.Copy(sourceActionBuffer.Array,
+     sourceActionBuffer.Offset,
+     destination.Array,
+     destination.Offset,
+     destination.Length);
}
}

for (var i = 0; i < m_Actuators.Count; i++)
{
var actuator = m_Actuators[i];
- m_DiscreteActionMask.CurrentBranchOffset = offset;
- actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
- offset += actuator.ActionSpec.NumDiscreteActions;
+ if (actuator.ActionSpec.NumDiscreteActions > 0)
+ {
+     m_DiscreteActionMask.CurrentBranchOffset = offset;
+     actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
+     offset += actuator.ActionSpec.NumDiscreteActions;
+ }
}
}

var continuousActions = ActionSegment<float>.Empty;
if (numContinuousActions > 0)
{
- continuousActions = new ActionSegment<float>(StoredContinuousActions,
+ continuousActions = new ActionSegment<float>(StoredActions.ContinuousActions.Array,
continuousStart,
numContinuousActions);
}

{
- discreteActions = new ActionSegment<int>(StoredDiscreteActions,
+ discreteActions = new ActionSegment<int>(StoredActions.DiscreteActions.Array,
discreteStart,
numDiscreteActions);
}

}
/// <summary>
/// Resets the <see cref="StoredContinuousActions"/> and <see cref="StoredDiscreteActions"/> buffers to be all
/// Resets the <see cref="ActionBuffers"/> to be all
/// zeros and calls <see cref="IActuator.ResetData"/> on each <see cref="IActuator"/> managed by this object.
/// </summary>
public void ResetData()

return;
}
- Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length);
- Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length);
+ StoredActions.Clear();
m_DiscreteActionMask.ResetMask();
}
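Note: a sketch of the new call shape for UpdateActions (not part of the diff), assuming actuators totalling three continuous actions and two discrete branches have already been added; it mirrors the test updates later in this commit:

// Old: manager.UpdateActions(continuousArray, discreteArray);
// New: both arrays travel together in one ActionBuffers value.
var manager = new ActuatorManager(1);
// ... Add() actuators here so the manager can initialize its buffers ...
manager.UpdateActions(new ActionBuffers(
    new[] { 0f, 1f, 2f },   // continuous actions
    new[] { 0, 1 }));       // one value per discrete branch
// The copies land in manager.StoredActions.ContinuousActions / .DiscreteActions.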

com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs (72 changes)


using System;
using System.Linq;
using UnityEngine;
namespace Unity.MLAgents.Actuators
{

/// </summary>
- internal readonly struct ActionBuffers
+ public readonly struct ActionBuffers
{
/// <summary>
/// An empty action buffer.

public ActionSegment<int> DiscreteActions { get; }
/// <summary>
/// Create an <see cref="ActionBuffers"/> instance with discrete actions stored as a float array. This exists
/// to achieve backward compatibility with the former Agent methods which used a float array for both continuous
/// and discrete actions.
/// </summary>
/// <param name="discreteActions">The float array of discrete actions.</param>
/// <returns>An <see cref="ActionBuffers"/> instance initialized with a <see cref="DiscreteActions"/>
/// <see cref="ActionSegment{T}"/> initialized from a float array.</returns>
public static ActionBuffers FromDiscreteActions(float[] discreteActions)
{
return new ActionBuffers(ActionSegment<float>.Empty, discreteActions == null ? ActionSegment<int>.Empty
: new ActionSegment<int>(Array.ConvertAll(discreteActions,
x => (int)x)));
}
public ActionBuffers(float[] continuousActions, int[] discreteActions)
: this(new ActionSegment<float>(continuousActions), new ActionSegment<int>(discreteActions)) { }
/// <summary>
/// Construct an <see cref="ActionBuffers"/> instance with the continuous and discrete actions that will
/// be used.
/// </summary>

DiscreteActions = discreteActions;
}
/// <summary>
/// Clear the <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/> segments to be all zeros.
/// </summary>
public void Clear()
{
ContinuousActions.Clear();
DiscreteActions.Clear();
}
/// <inheritdoc cref="ValueType.Equals(object)"/>
public override bool Equals(object obj)
{

return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode();
}
}
/// <summary>
/// Packs the continuous and discrete actions into one float array. The array passed into this method
/// must have a Length that is greater than or equal to the sum of the Lengths of
/// <see cref="ContinuousActions"/> and <see cref="DiscreteActions"/>.
/// </summary>
/// <param name="destination">A float array to pack actions into whose length is greater than or
/// equal to the addition of the Lengths of this objects <see cref="ContinuousActions"/> and
/// <see cref="DiscreteActions"/> segments.</param>
public void PackActions(in float[] destination)
{
Debug.Assert(destination.Length >= ContinuousActions.Length + DiscreteActions.Length,
$"argument '{nameof(destination)}' is not large enough to pack the actions into.\n" +
$"{nameof(destination)}.Length: {destination.Length}\n" +
$"{nameof(ContinuousActions)}.Length + {nameof(DiscreteActions)}.Length: {ContinuousActions.Length + DiscreteActions.Length}");
var start = 0;
if (ContinuousActions.Length > 0)
{
Array.Copy(ContinuousActions.Array,
ContinuousActions.Offset,
destination,
start,
ContinuousActions.Length);
start = ContinuousActions.Length;
}
if (start >= destination.Length)
{
return;
}
if (DiscreteActions.Length > 0)
{
Array.Copy(DiscreteActions.Array,
DiscreteActions.Offset,
destination,
start,
DiscreteActions.Length);
}
}
- internal interface IActionReceiver
+ public interface IActionReceiver
{
/// <summary>
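Note: given the documented contract of PackActions (destination length must be at least ContinuousActions.Length + DiscreteActions.Length), a small sketch of the flattening used for backward compatibility (not part of the diff):

var buffers = new ActionBuffers(new[] { 0.5f, -0.5f }, new[] { 2 });
var packed = new float[3];    // 2 continuous + 1 discrete
buffers.PackActions(packed);
// packed is now { 0.5f, -0.5f, 2f }: continuous first, then discrete widened to float.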

com.unity.ml-agents/Runtime/Actuators/IActuator.cs (2 changes)


/// <summary>
/// Abstraction that facilitates the execution of actions.
/// </summary>
- internal interface IActuator : IActionReceiver
+ public interface IActuator : IActionReceiver
{
int TotalNumberOfActions { get; }

com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs (2 changes)


/// <summary>
/// Interface for writing a mask to disable discrete actions for agents for the next decision.
/// </summary>
- internal interface IDiscreteActionMask
+ public interface IDiscreteActionMask
{
/// <summary>
/// Modifies an action mask for discrete control agents.

com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs (2 changes)


namespace Unity.MLAgents.Actuators
{
- internal class VectorActuator : IActuator
+ public class VectorActuator : IActuator
{
IActionReceiver m_ActionReceiver;
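Note: with IActuator public, custom actuators become possible. A hypothetical sketch; only the members visible in this diff (OnActionReceived, WriteDiscreteActionMask, ActionSpec, TotalNumberOfActions, ResetData) are grounded, and the Name member is an assumption inferred from the "sort the Actuators by name" comment in ActuatorManager:

// Hypothetical one-branch jump actuator; the jump logic itself is illustrative.
public class JumpActuator : IActuator
{
    public ActionSpec ActionSpec { get; }   // assumed: describes one discrete branch of size 2
    public int TotalNumberOfActions => 1;
    public string Name => "JumpActuator";   // assumed member; actuators are sorted by name

    public void OnActionReceived(ActionBuffers actionBuffers)
    {
        if (actionBuffers.DiscreteActions[0] == 1)
        {
            // trigger the jump
        }
    }

    public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        // e.g. disallow jumping while airborne:
        // actionMask.WriteMask(0, new[] { 1 });
    }

    public void ResetData() { }
}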

com.unity.ml-agents/Runtime/Agent.cs (201 changes)


using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using Unity.MLAgents.Demonstrations;

/// to separate between different agents in the environment.
/// </summary>
public int episodeId;
}
- /// <summary>
- /// Struct that contains the action information sent from the Brain to the
- /// Agent.
- /// </summary>
- internal struct AgentAction
- {
-     public float[] vectorActions;
- }
+ public void ClearActions()
+ {
+     Array.Clear(storedVectorActions, 0, storedVectorActions.Length);
+ }
+ public void CopyActions(ActionBuffers actionBuffers)
+ {
+     actionBuffers.PackActions(storedVectorActions);
+ }
/// <summary>

/// can only take an action when it touches the ground, so several frames might elapse between
/// one decision and the need for the next.
///
/// Use the <see cref="OnActionReceived"/> function to implement the actions your agent can take,
/// Use the <see cref="OnActionReceived(float[])"/> function to implement the actions your agent can take,
/// such as moving to reach a goal or interacting with its environment.
///
/// When you call <see cref="EndEpisode"/> on an agent or the agent reaches its <see cref="MaxStep"/> count,

"docs/Learning-Environment-Design-Agents.md")]
[Serializable]
[RequireComponent(typeof(BehaviorParameters))]
- public class Agent : MonoBehaviour, ISerializationCallbackReceiver
+ public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver
{
IPolicy m_Brain;
BehaviorParameters m_PolicyFactory;

/// Current Agent information (message sent to Brain).
AgentInfo m_Info;
- /// Current Agent action (message sent from Brain).
- AgentAction m_Action;
/// Represents the reward the agent accumulated during the current step.
/// It is reset to 0 at the beginning of every step.

internal VectorSensor collectObservationsSensor;
/// <summary>
/// List of IActuators that this Agent will delegate actions to if any exist.
/// </summary>
ActuatorManager m_ActuatorManager;
/// <summary>
/// VectorActuator which is used by default if no other sensors exist on this Agent. This VectorSensor will
/// delegate its actions to <see cref="OnActionReceived(float[])"/> by default in order to keep backward compatibility
/// with the current behavior of Agent.
/// </summary>
IActuator m_VectorActuator;
/// <summary>
/// This is used to avoid allocation of a float array every frame if users are still using the old
/// OnActionReceived method.
/// </summary>
float[] m_LegacyActionCache;
/// <summary>
/// Called when the attached [GameObject] becomes enabled and active.
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
/// </summary>

m_PolicyFactory = GetComponent<BehaviorParameters>();
m_Info = new AgentInfo();
m_Action = new AgentAction();
sensors = new List<ISensor>();
Academy.Instance.AgentIncrementStep += AgentIncrementStep;

InitializeSensors();
}
using (TimerStack.Instance.Scoped("InitializeActuators"))
{
InitializeActuators();
}
m_Info.storedVectorActions = new float[m_ActuatorManager.TotalNumberOfActions];
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.
// To avoid the Agent resetting twice, the Agents will not begin their

/// set the reward assigned to the current step with a specific value rather than
/// increasing or decreasing it.
///
- /// Typically, you assign rewards in the Agent subclass's <see cref="OnActionReceived(float[])"/>
+ /// Typically, you assign rewards in the Agent subclass's <see cref="IActionReceiver.OnActionReceived"/>
/// implementation after carrying out the received action and evaluating its success.
///
/// Rewards are used during reinforcement learning; they are ignored during inference.

/// <remarks>
/// Call `RequestAction()` to repeat the previous action returned by the agent's
/// most recent decision. A new decision is not requested. When you call this function,
- /// the Agent instance invokes <seealso cref="OnActionReceived(float[])"/> with the
+ /// the Agent instance invokes <seealso cref="IActionReceiver.OnActionReceived"/> with the
/// existing action vector.
///
/// You can use `RequestAction()` in situations where an agent must take an action

/// at the end of an episode.
void ResetData()
{
- var param = m_PolicyFactory.BrainParameters;
- m_ActionMasker = new DiscreteActionMasker(param);
- // If we haven't initialized vectorActions, initialize to 0. This should only
- // happen during the creation of the Agent. In subsequent episodes, vectorAction
- // should stay the previous action before the Done(), so that it is properly recorded.
- if (m_Action.vectorActions == null)
- {
-     m_Action.vectorActions = new float[param.NumActions];
-     m_Info.storedVectorActions = new float[param.NumActions];
- }
+ m_ActuatorManager?.ResetData();
}
/// <summary>

/// control of an agent using keyboard, mouse, or game controller input.
///
/// Your heuristic implementation can use any decision making logic you specify. Assign decision
- /// values to the float[] array, <paramref name="actionsOut"/>, passed to your function as a parameter.
+ /// values to the <see cref="ActionBuffers.ContinuousActions"/> and <see cref="ActionBuffers.DiscreteActions"/>
+ /// arrays, passed to your function as a parameter.
- /// <seealso cref="OnActionReceived(float[])"/> function, which receives this array and
+ /// <seealso cref="IActionReceiver.OnActionReceived"/> function, which receives this array and
/// implements the corresponding agent behavior. See [Actions] for more information
/// about agent actions.
/// Note: Do not create a new float array of actions in the `Heuristic()` method,

/// You can also use the [Input System package], which provides a more flexible and
/// configurable input system.
/// <code>
- /// public override void Heuristic(float[] actionsOut)
+ /// public override void Heuristic(ActionBuffers actionsOut)
- /// actionsOut[0] = Input.GetAxis("Horizontal");
- /// actionsOut[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
- /// actionsOut[2] = Input.GetAxis("Vertical");
+ /// actionsOut.ContinuousActions[0] = Input.GetAxis("Horizontal");
+ /// actionsOut.ContinuousActions[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
+ /// actionsOut.ContinuousActions[2] = Input.GetAxis("Vertical");
/// <param name="actionsOut">Array for the output actions.</param>
/// <seealso cref="OnActionReceived(float[])"/>
public virtual void Heuristic(float[] actionsOut)
/// <param name="actionsOut">The <see cref="ActionBuffers"/> which contain the continuous and
/// discrete action buffers to write to.</param>
/// <seealso cref="IActionReceiver.OnActionReceived"/>
public virtual void Heuristic(in ActionBuffers actionsOut)
Debug.LogWarning("Heuristic method called but not implemented. Returning placeholder actions.");
Array.Clear(actionsOut, 0, actionsOut.Length);
// For backward compatibility
switch (m_PolicyFactory.BrainParameters.VectorActionSpaceType)
{
case SpaceType.Continuous:
Heuristic(actionsOut.ContinuousActions.Array);
actionsOut.DiscreteActions.Clear();
break;
case SpaceType.Discrete:
var convertedOut = Array.ConvertAll(actionsOut.DiscreteActions.Array, x => (float)x);
Heuristic(convertedOut);
var discreteActionSegment = actionsOut.DiscreteActions;
for (var i = 0; i < actionsOut.DiscreteActions.Length; i++)
{
discreteActionSegment[i] = (int)convertedOut[i];
}
actionsOut.ContinuousActions.Clear();
break;
}
}
/// <summary>

#if DEBUG
// Make sure the names are actually unique
for (var i = 0; i < sensors.Count - 1; i++)
{
Debug.Assert(

#endif
}
void InitializeActuators()
{
ActuatorComponent[] attachedActuators;
if (m_PolicyFactory.UseChildActuators)
{
attachedActuators = GetComponentsInChildren<ActuatorComponent>();
}
else
{
attachedActuators = GetComponents<ActuatorComponent>();
}
// Support legacy OnActionReceived
var param = m_PolicyFactory.BrainParameters;
m_VectorActuator = new VectorActuator(this, param.VectorActionSize, param.VectorActionSpaceType);
m_ActuatorManager = new ActuatorManager(attachedActuators.Length + 1);
m_LegacyActionCache = new float[m_VectorActuator.TotalNumberOfActions];
m_ActuatorManager.Add(m_VectorActuator);
foreach (var actuatorComponent in attachedActuators)
{
m_ActuatorManager.Add(actuatorComponent.CreateActuator());
}
}
/// <summary>
/// Sends the Agent info to the linked Brain.
/// </summary>

if (m_Info.done)
{
- Array.Clear(m_Info.storedVectorActions, 0, m_Info.storedVectorActions.Length);
+ m_Info.ClearActions();
- Array.Copy(m_Action.vectorActions, m_Info.storedVectorActions, m_Action.vectorActions.Length);
+ m_ActuatorManager.StoredActions.PackActions(m_Info.storedVectorActions);
- m_ActionMasker.ResetMask();
UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
{

{
- if (m_PolicyFactory.BrainParameters.VectorActionSpaceType == SpaceType.Discrete)
- {
-     CollectDiscreteActionMasks(m_ActionMasker);
- }
+ m_ActuatorManager.WriteActionMask();
- m_Info.discreteActionMasks = m_ActionMasker.GetMask();
+ m_Info.discreteActionMasks = m_ActuatorManager.DiscreteActionMask?.GetMask();
m_Info.reward = m_Reward;
m_Info.done = false;
m_Info.maxStepReached = false;

/// <summary>
/// Returns a read-only view of the observations that were generated in
/// <see cref="CollectObservations(VectorSensor)"/>. This is mainly useful inside of a
/// <see cref="Heuristic(float[])"/> method to avoid recomputing the observations.
/// <see cref="Heuristic(float[], int[])"/> method to avoid recomputing the observations.
/// </summary>
/// <returns>A read-only view of the observations list.</returns>
public ReadOnlyCollection<float> GetObservations()

///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <seealso cref="OnActionReceived(float[])"/>
public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
/// <seealso cref="IActionReceiver.OnActionReceived"/>
public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
if (m_ActionMasker == null)
{
m_ActionMasker = new DiscreteActionMasker(actionMask);
}
CollectDiscreteActionMasks(m_ActionMasker);
ActionSpec IActionReceiver.ActionSpec { get; }
/// <summary>
/// Implement `OnActionReceived()` to specify agent behavior at every step, based

/// three values in the action array to use as the force components. During
/// training, the agent's policy learns to set those particular elements of
/// the array to maximize the training rewards the agent receives. (Of course,
/// if you implement a <seealso cref="Heuristic"/> function, it must use the same
/// if you implement a <seealso cref="Heuristic(float[], int[])"/> function, it must use the same
/// elements of the action array for the same purpose since there is no learning
/// involved.)
///

///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_5_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <param name="vectorAction">
/// An array containing the action vector. The length of the array is specified
/// by the <see cref="BrainParameters"/> of the agent's associated
/// <see cref="BehaviorParameters"/> component.
/// <param name="actions">
/// Struct containing the buffers of actions to be executed at this step.
public virtual void OnActionReceived(float[] vectorAction) {}
public virtual void OnActionReceived(ActionBuffers actions)
{
actions.PackActions(m_LegacyActionCache);
OnActionReceived(m_LegacyActionCache);
}
/// <summary>
/// Implement `OnEpisodeBegin()` to set up an Agent instance at the beginning

public virtual void OnEpisodeBegin() {}
/// <summary>
- /// Returns the last action that was decided on by the Agent.
+ /// Gets the last ActionBuffer for this agent.
/// <returns>
/// The last action that was decided by the Agent (or null if no decision has been made).
/// </returns>
- /// <seealso cref="OnActionReceived(float[])"/>
- public float[] GetAction()
+ public ActionBuffers GetStoredContinuousActions()
- return m_Action.vectorActions;
+ return m_ActuatorManager.StoredActions;
}
/// <summary>

if ((m_RequestAction) && (m_Brain != null))
{
m_RequestAction = false;
- OnActionReceived(m_Action.vectorActions);
+ m_ActuatorManager.ExecuteActions();
}
if ((m_StepCount >= MaxStep) && (MaxStep > 0))

void DecideAction()
{
- if (m_Action.vectorActions == null)
+ if (m_ActuatorManager.StoredActions.ContinuousActions.Array == null)
- var action = m_Brain?.DecideAction();
- if (action == null)
- {
-     Array.Clear(m_Action.vectorActions, 0, m_Action.vectorActions.Length);
- }
- else
- {
-     Array.Copy(action, m_Action.vectorActions, action.Length);
- }
+ var actions = m_Brain?.DecideAction() ?? new ActionBuffers();
+ m_Info.CopyActions(actions);
+ m_ActuatorManager.UpdateActions(actions);
}
}
}
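Note: read together, the signature changes above imply this migration for user Agents; the Heuristic body restates the example from the doc comment rather than inventing behavior, and MyAgent is an illustrative name:

public class MyAgent : Agent
{
    // New-style override: continuous and discrete actions arrive in one struct.
    public override void OnActionReceived(ActionBuffers actions)
    {
        var moveX = actions.ContinuousActions[0];
        // ... apply forces, assign rewards, etc.
    }

    // Heuristic now writes into the provided buffers instead of a float[].
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        actionsOut.ContinuousActions[0] = Input.GetAxis("Horizontal");
        actionsOut.ContinuousActions[1] = Input.GetKey(KeyCode.Space) ? 1.0f : 0.0f;
        actionsOut.ContinuousActions[2] = Input.GetAxis("Vertical");
    }
}

Legacy float[] overrides keep working: the new OnActionReceived(ActionBuffers) packs both segments into m_LegacyActionCache and forwards them to OnActionReceived(float[]).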

com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs (14 changes)


}
- #region AgentAction
- public static AgentAction ToAgentAction(this AgentActionProto aap)
- {
-     return new AgentAction
-     {
-         vectorActions = aap.VectorActions.ToArray()
-     };
- }
- public static List<AgentAction> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
+ public static List<float[]> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
- var agentActions = new List<AgentAction>(proto.Value.Count);
+ var agentActions = new List<float[]>(proto.Value.Count);
- agentActions.Add(ap.ToAgentAction());
+ agentActions.Add(ap.VectorActions.ToArray());
}
return agentActions;
}

com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs (2 changes)


var agentId = m_OrderedAgentsRequestingDecisions[brainName][i];
if (m_LastActionsReceived[brainName].ContainsKey(agentId))
{
- m_LastActionsReceived[brainName][agentId] = agentAction.vectorActions;
+ m_LastActionsReceived[brainName][agentId] = agentAction;
}
}
}

com.unity.ml-agents/Runtime/DecisionRequester.cs (6 changes)


/// that the Agent will request a decision every 5 Academy steps.
/// </summary>
[Range(1, 20)]
[Tooltip("The frequency with which the agent requests a decision. A DecisionPeriod " +
"of 5 means that the Agent will request a decision every 5 Academy steps.")]
"of 5 means that the Agent will request a decision every 5 Academy steps.")]
public int DecisionPeriod = 5;
/// <summary>

[Tooltip("Indicates whether or not the agent will take an action during the Academy " +
"steps where it does not request a decision. Has no effect when DecisionPeriod " +
"is set to 1.")]
"steps where it does not request a decision. Has no effect when DecisionPeriod " +
"is set to 1.")]
[FormerlySerializedAs("RepeatAction")]
public bool TakeActionsBetweenDecisions = true;

com.unity.ml-agents/Runtime/DiscreteActionMasker.cs (118 changes)


using System;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Actuators;
namespace Unity.MLAgents
{

/// may be illegal. For example, if an agent is adjacent to a wall or other obstacle
/// you could mask any actions that direct the agent to move into the blocked space.
/// </remarks>
- public class DiscreteActionMasker
+ public class DiscreteActionMasker : IDiscreteActionMask
- /// When using discrete control, is the starting indices of the actions
- /// when all the branches are concatenated with each other.
- int[] m_StartingActionIndices;
- bool[] m_CurrentMask;
- readonly BrainParameters m_BrainParameters;
+ IDiscreteActionMask m_Delegate;
- internal DiscreteActionMasker(BrainParameters brainParameters)
+ internal DiscreteActionMasker(IDiscreteActionMask actionMask)
- m_BrainParameters = brainParameters;
+ m_Delegate = actionMask;
}
/// <summary>

/// <param name="actionIndices">The indices of the masked actions.</param>
public void SetMask(int branch, IEnumerable<int> actionIndices)
{
// If the branch does not exist, raise an error
if (branch >= m_BrainParameters.VectorActionSize.Length)
throw new UnityAgentsException(
"Invalid Action Masking : Branch " + branch + " does not exist.");
var totalNumberActions = m_BrainParameters.VectorActionSize.Sum();
// By default, the masks are null. If we want to specify a new mask, we initialize
// the actionMasks with trues.
if (m_CurrentMask == null)
{
m_CurrentMask = new bool[totalNumberActions];
}
// If this is the first time the masked actions are used, we generate the starting
// indices for each branch.
if (m_StartingActionIndices == null)
{
m_StartingActionIndices = Utilities.CumSum(m_BrainParameters.VectorActionSize);
}
// Perform the masking
foreach (var actionIndex in actionIndices)
{
if (actionIndex >= m_BrainParameters.VectorActionSize[branch])
{
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
m_CurrentMask[actionIndex + m_StartingActionIndices[branch]] = true;
}
}
/// <summary>
/// Get the current mask for an agent.
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
internal bool[] GetMask()
{
if (m_CurrentMask != null)
{
AssertMask();
}
return m_CurrentMask;
m_Delegate.WriteMask(branch, actionIndices);
/// <summary>
/// Makes sure that the current mask is usable.
/// </summary>
void AssertMask()
public void WriteMask(int branch, IEnumerable<int> actionIndices)
// Action Masks can only be used in Discrete Control.
if (m_BrainParameters.VectorActionSpaceType != SpaceType.Discrete)
{
throw new UnityAgentsException(
"Invalid Action Masking : Can only set action mask for Discrete Control.");
}
var numBranches = m_BrainParameters.VectorActionSize.Length;
for (var branchIndex = 0; branchIndex < numBranches; branchIndex++)
{
if (AreAllActionsMasked(branchIndex))
{
throw new UnityAgentsException(
"Invalid Action Masking : All the actions of branch " + branchIndex +
" are masked.");
}
}
m_Delegate.WriteMask(branch, actionIndices);
/// <summary>
/// Resets the current mask for an agent.
/// </summary>
internal void ResetMask()
public bool[] GetMask()
if (m_CurrentMask != null)
{
Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
}
return m_Delegate.GetMask();
/// <summary>
/// Checks if all the actions in the input branch are masked.
/// </summary>
/// <param name="branch"> The index of the branch to check.</param>
/// <returns> True if all the actions of the branch are masked.</returns>
bool AreAllActionsMasked(int branch)
public void ResetMask()
if (m_CurrentMask == null)
{
return false;
}
var start = m_StartingActionIndices[branch];
var end = m_StartingActionIndices[branch + 1];
for (var i = start; i < end; i++)
{
if (!m_CurrentMask[i])
{
return false;
}
}
return true;
m_Delegate.ResetMask();
}
}
}
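Note: the masking entry point moves from CollectDiscreteActionMasks(DiscreteActionMasker) to WriteDiscreteActionMask(IDiscreteActionMask). A sketch of the new override (GridAgent and the branch/index values are illustrative):

public class GridAgent : Agent
{
    public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        // Forbid actions 1 and 2 on branch 0 for the next decision.
        actionMask.WriteMask(0, new[] { 1, 2 });
    }
}

Existing CollectDiscreteActionMasks overrides still run: the new Agent.WriteDiscreteActionMask wraps the IDiscreteActionMask in a DiscreteActionMasker and calls the old virtual.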

com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs (17 changes)


using System;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;

internal class BarracudaPolicy : IPolicy
{
protected ModelRunner m_ModelRunner;
ActionBuffers m_LastActionBuffer;
int m_AgentId;

List<int[]> m_SensorShapes;
SpaceType m_SpaceType;
/// <inheritdoc />
public BarracudaPolicy(

{
var modelRunner = Academy.Instance.GetOrCreateModelRunner(model, brainParameters, inferenceDevice);
m_ModelRunner = modelRunner;
m_SpaceType = brainParameters.VectorActionSpaceType;
}
/// <inheritdoc />

}
/// <inheritdoc />
- public float[] DecideAction()
+ public ref readonly ActionBuffers DecideAction()
- return m_ModelRunner?.GetAction(m_AgentId);
+ var actions = m_ModelRunner?.GetAction(m_AgentId);
+ if (m_SpaceType == SpaceType.Continuous)
+ {
+     m_LastActionBuffer = new ActionBuffers(actions, Array.Empty<int>());
+     return ref m_LastActionBuffer;
+ }
+ m_LastActionBuffer = ActionBuffers.FromDiscreteActions(actions);
+ return ref m_LastActionBuffer;
}
public void Dispose()

com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs (37 changes)


[Tooltip("Use all Sensor components attached to child GameObjects of this Agent.")]
bool m_UseChildSensors = true;
[HideInInspector]
[SerializeField]
[Tooltip("Use all Actuator components attached to child GameObjects of this Agent.")]
bool m_UseChildActuators = true;
/// <summary>
/// Whether or not to use all the sensor components attached to child GameObjects of the agent.
/// Note that changing this after the Agent has been initialized will not have any effect.

set { m_UseChildSensors = value; }
}
/// <summary>
/// Whether or not to use all the actuator components attached to child GameObjects of the agent.
/// Note that changing this after the Agent has been initialized will not have any effect.
/// </summary>
public bool UseChildActuators
{
get { return m_UseChildActuators; }
set { m_UseChildActuators = value; }
}
[HideInInspector, SerializeField]
ObservableAttributeOptions m_ObservableAttributeHandling = ObservableAttributeOptions.Ignore;

switch (m_BehaviorType)
{
case BehaviorType.HeuristicOnly:
- return new HeuristicPolicy(heuristic, m_BrainParameters.NumActions);
+ return GenerateHeuristicPolicy(heuristic);
case BehaviorType.InferenceOnly:
{
if (m_Model == null)

}
else
{
- return new HeuristicPolicy(heuristic, m_BrainParameters.NumActions);
+ return GenerateHeuristicPolicy(heuristic);
- return new HeuristicPolicy(heuristic, m_BrainParameters.NumActions);
+ return GenerateHeuristicPolicy(heuristic);
}
internal IPolicy GenerateHeuristicPolicy(HeuristicPolicy.ActionGenerator heuristic)
{
var numContinuousActions = 0;
var numDiscreteActions = 0;
if (m_BrainParameters.VectorActionSpaceType == SpaceType.Continuous)
{
numContinuousActions = m_BrainParameters.NumActions;
}
else if (m_BrainParameters.VectorActionSpaceType == SpaceType.Discrete)
{
numDiscreteActions = m_BrainParameters.NumActions;
}
return new HeuristicPolicy(heuristic, numContinuousActions, numDiscreteActions);
}
internal void UpdateAgentPolicy()
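Note: a short usage sketch for the new flag (not part of the diff; mirrors UseChildSensors):

// Must be set before the Agent initializes to have any effect.
var behaviorParams = GetComponent<BehaviorParameters>();
behaviorParams.UseChildActuators = false; // only ActuatorComponents on this GameObject are collected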

com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs (20 changes)


using System.Collections.Generic;
using System;
using System.Collections;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Policies

/// </summary>
internal class HeuristicPolicy : IPolicy
{
- public delegate void ActionGenerator(float[] actionsOut);
+ public delegate void ActionGenerator(in ActionBuffers actionBuffers);
- float[] m_LastDecision;
+ ActionBuffers m_ActionBuffers;
bool m_Done;
bool m_DecisionRequested;

/// <inheritdoc />
- public HeuristicPolicy(ActionGenerator heuristic, int numActions)
+ public HeuristicPolicy(ActionGenerator heuristic, int numContinuousActions, int numDiscreteActions)
- m_LastDecision = new float[numActions];
+ var continuousDecision = new ActionSegment<float>(new float[numContinuousActions], 0, numContinuousActions);
+ var discreteDecision = new ActionSegment<int>(new int[numDiscreteActions], 0, numDiscreteActions);
+ m_ActionBuffers = new ActionBuffers(continuousDecision, discreteDecision);
}
/// <inheritdoc />

m_Done = info.done;
m_DecisionRequested = true;
- public float[] DecideAction()
+ public ref readonly ActionBuffers DecideAction()
- m_Heuristic.Invoke(m_LastDecision);
+ m_Heuristic.Invoke(m_ActionBuffers);
- return m_LastDecision;
+ return ref m_ActionBuffers;
}
public void Dispose()

public float this[int index]
{
get { return 0.0f; }
- set { }
+ set {}
}
}

com.unity.ml-agents/Runtime/Policies/IPolicy.cs (3 changes)


using System;
using System.Collections.Generic;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Policies

/// it must be taken now. The Brain is expected to update the actions
/// of the Agents at this point the latest.
/// </summary>
- float[] DecideAction();
+ ref readonly ActionBuffers DecideAction();
}
}
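Note: the new contract returns the decision by readonly reference, so implementations cache an ActionBuffers in a field, as BarracudaPolicy, RemotePolicy, and HeuristicPolicy now do. A stripped-down hypothetical policy illustrating the pattern (ConstantPolicy is not part of the diff):

// Hypothetical constant policy; only the DecideAction shape matters here.
internal class ConstantPolicy : IPolicy
{
    ActionBuffers m_LastActionBuffer =
        new ActionBuffers(new[] { 0.5f }, Array.Empty<int>());

    public void RequestDecision(AgentInfo info, List<ISensor> sensors) { }

    public ref readonly ActionBuffers DecideAction()
    {
        // Returning by ref requires the buffer to outlive the call,
        // hence the m_LastActionBuffer field.
        return ref m_LastActionBuffer;
    }

    public void Dispose() { }
}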

com.unity.ml-agents/Runtime/Policies/RemotePolicy.cs (15 changes)


using UnityEngine;
using System.Collections.Generic;
using System;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Policies

{
int m_AgentId;
string m_FullyQualifiedBehaviorName;
SpaceType m_SpaceType;
ActionBuffers m_LastActionBuffer;
internal ICommunicator m_Communicator;

{
m_FullyQualifiedBehaviorName = fullyQualifiedBehaviorName;
m_Communicator = Academy.Instance.Communicator;
m_SpaceType = brainParameters.VectorActionSpaceType;
m_Communicator.SubscribeBrain(m_FullyQualifiedBehaviorName, brainParameters);
}

}
/// <inheritdoc />
- public float[] DecideAction()
+ public ref readonly ActionBuffers DecideAction()
- return m_Communicator?.GetActions(m_FullyQualifiedBehaviorName, m_AgentId);
+ var actions = m_Communicator?.GetActions(m_FullyQualifiedBehaviorName, m_AgentId);
+ if (m_SpaceType == SpaceType.Continuous)
+ {
+     m_LastActionBuffer = new ActionBuffers(actions, Array.Empty<int>());
+     return ref m_LastActionBuffer;
+ }
+ m_LastActionBuffer = ActionBuffers.FromDiscreteActions(actions);
+ return ref m_LastActionBuffer;
}
public void Dispose()

com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs (46 changes)


actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes,
actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions);
- manager.UpdateActions(new[]
-     { 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f }, Array.Empty<int>());
+ manager.UpdateActions(new ActionBuffers(new[]
+     { 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f }, Array.Empty<int>()));
- Assert.IsTrue(12 == manager.StoredContinuousActions.Length);
- Assert.IsTrue(0 == manager.StoredDiscreteActions.Length);
+ Assert.IsTrue(12 == manager.StoredActions.ContinuousActions.Length);
+ Assert.IsTrue(0 == manager.StoredActions.DiscreteActions.Length);
}
[Test]

actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes,
actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions);
- manager.UpdateActions(Array.Empty<float>(),
-     new[] { 0, 1, 2, 3, 4, 5, 6});
+ manager.UpdateActions(new ActionBuffers(Array.Empty<float>(),
+     new[] { 0, 1, 2, 3, 4, 5, 6}));
- Assert.IsTrue(0 == manager.StoredContinuousActions.Length);
- Assert.IsTrue(7 == manager.StoredDiscreteActions.Length);
+ Assert.IsTrue(0 == manager.StoredActions.ContinuousActions.Length);
+ Assert.IsTrue(7 == manager.StoredActions.DiscreteActions.Length);
}
[Test]

manager.Add(actuator2);
var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5, 6};
- manager.UpdateActions(Array.Empty<float>(),
-     discreteActionBuffer);
+ manager.UpdateActions(new ActionBuffers(Array.Empty<float>(),
+     discreteActionBuffer));
manager.ExecuteActions();
var actuator1Actions = actuator1.LastActionBuffer.DiscreteActions;

manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
- manager.UpdateActions(continuousActionBuffer,
-     Array.Empty<int>());
+ manager.UpdateActions(new ActionBuffers(continuousActionBuffer,
+     Array.Empty<int>()));
manager.ExecuteActions();
var actuator1Actions = actuator1.LastActionBuffer.ContinuousActions;

manager.Add(actuator1);
manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
- manager.UpdateActions(continuousActionBuffer,
-     Array.Empty<int>());
+ manager.UpdateActions(new ActionBuffers(continuousActionBuffer,
+     Array.Empty<int>()));
- Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer));
+ Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(continuousActionBuffer));
}
[Test]

manager.Add(actuator1);
manager.Add(actuator2);
var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5};
- manager.UpdateActions(Array.Empty<float>(),
-     discreteActionBuffer);
+ manager.UpdateActions(new ActionBuffers(Array.Empty<float>(),
+     discreteActionBuffer));
- Debug.Log(manager.StoredDiscreteActions);
+ Debug.Log(manager.StoredActions.DiscreteActions);
- Assert.IsTrue(manager.StoredDiscreteActions.SequenceEqual(discreteActionBuffer));
+ Assert.IsTrue(manager.StoredActions.DiscreteActions.SequenceEqual(discreteActionBuffer));
}
[Test]

manager.Add(actuator1);
manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
- manager.UpdateActions(continuousActionBuffer,
-     Array.Empty<int>());
+ manager.UpdateActions(new ActionBuffers(continuousActionBuffer,
+     Array.Empty<int>()));
- Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer));
+ Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(continuousActionBuffer));
- Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(new[] { 0f, 0f, 0f, 0f, 0f, 0f}));
+ Assert.IsTrue(manager.StoredActions.ContinuousActions.SequenceEqual(new[] { 0f, 0f, 0f, 0f, 0f, 0f}));
}
[Test]

com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs (3 changes)


using NUnit.Framework;
using Unity.MLAgents.Actuators;
using UnityEngine;
using Unity.MLAgents.Policies;

public class BehaviorParameterTests
{
- static void DummyHeuristic(float[] actionsOut)
+ static void DummyHeuristic(in ActionBuffers actionsOut)
{
// No-op
}

com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs (4 changes)


using NUnit.Framework;
using System.Reflection;
using System.Collections.Generic;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using Unity.MLAgents.Policies;

{
public Action OnRequestDecision;
ObservationWriter m_ObsWriter = new ObservationWriter();
static ActionBuffers s_EmptyActionBuffers = new ActionBuffers(Array.Empty<float>(), Array.Empty<int>());
public void RequestDecision(AgentInfo info, List<ISensor> sensors)
{
foreach (var sensor in sensors)

OnRequestDecision?.Invoke();
}
- public float[] DecideAction() { return new float[0]; }
+ public ref readonly ActionBuffers DecideAction() { return ref s_EmptyActionBuffers; }
public void Dispose() {}
}

com.unity.ml-agents/Runtime/Agent.deprecated.cs (38 changes)


using System;
using UnityEngine;
namespace Unity.MLAgents
{
public partial class Agent
{
public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
{
}
/// <summary>
/// This method passes in a float array that is to be populated with actions. The actions
/// </summary>
/// <param name="actionsOut"></param>
public virtual void Heuristic(float[] actionsOut)
{
Debug.LogWarning("Heuristic method called but not implemented. Returning placeholder actions.");
Array.Clear(actionsOut, 0, actionsOut.Length);
}
public virtual void OnActionReceived(float[] vectorAction) {}
/// <summary>
/// Returns the last action that was decided on by the Agent.
/// </summary>
/// <returns>
/// The last action that was decided by the Agent (or null if no decision has been made).
/// </returns>
/// <seealso cref="OnActionReceived(float[])"/>
// [Obsolete("GetAction has been deprecated, please use GetStoredContinuousActions, Or GetStoredDiscreteActions.")]
public float[] GetAction()
{
return m_Info.storedVectorActions;
}
}
}

com.unity.ml-agents/Runtime/Agent.deprecated.cs.meta (3 changes)


fileFormatVersion: 2
guid: 9650a482703b47db8cd7fb2df8caa1bf
timeCreated: 1595613441

com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs.meta (11 changes)


fileFormatVersion: 2
guid: 2e2810ee6c8c64fb39abdf04b5d17f50
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs (142 changes)


using NUnit.Framework;
using Unity.MLAgents.Policies;
namespace Unity.MLAgents.Tests
{
public class EditModeTestActionMasker
{
[Test]
public void Contruction()
{
var bp = new BrainParameters();
var masker = new DiscreteActionMasker(bp);
Assert.IsNotNull(masker);
}
[Test]
public void FailsWithContinuous()
{
var bp = new BrainParameters();
bp.VectorActionSpaceType = SpaceType.Continuous;
bp.VectorActionSize = new[] {4};
var masker = new DiscreteActionMasker(bp);
masker.SetMask(0, new[] {0});
Assert.Catch<UnityAgentsException>(() => masker.GetMask());
}
[Test]
public void NullMask()
{
var bp = new BrainParameters();
bp.VectorActionSpaceType = SpaceType.Discrete;
var masker = new DiscreteActionMasker(bp);
var mask = masker.GetMask();
Assert.IsNull(mask);
}
[Test]
public void FirstBranchMask()
{
var bp = new BrainParameters();
bp.VectorActionSpaceType = SpaceType.Discrete;
bp.VectorActionSize = new[] {4, 5, 6};
var masker = new DiscreteActionMasker(bp);
var mask = masker.GetMask();
Assert.IsNull(mask);
masker.SetMask(0, new[] {1, 2, 3});
mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsTrue(mask[1]);
Assert.IsTrue(mask[2]);
Assert.IsTrue(mask[3]);
Assert.IsFalse(mask[4]);
Assert.AreEqual(mask.Length, 15);
}
[Test]
public void SecondBranchMask()
{
var bp = new BrainParameters
{
VectorActionSpaceType = SpaceType.Discrete,
VectorActionSize = new[] { 4, 5, 6 }
};
var masker = new DiscreteActionMasker(bp);
masker.SetMask(1, new[] {1, 2, 3});
var mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsFalse(mask[4]);
Assert.IsTrue(mask[5]);
Assert.IsTrue(mask[6]);
Assert.IsTrue(mask[7]);
Assert.IsFalse(mask[8]);
Assert.IsFalse(mask[9]);
}
[Test]
public void MaskReset()
{
var bp = new BrainParameters
{
VectorActionSpaceType = SpaceType.Discrete,
VectorActionSize = new[] { 4, 5, 6 }
};
var masker = new DiscreteActionMasker(bp);
masker.SetMask(1, new[] {1, 2, 3});
masker.ResetMask();
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)
{
Assert.IsFalse(mask[i]);
}
}
[Test]
public void ThrowsError()
{
var bp = new BrainParameters
{
VectorActionSpaceType = SpaceType.Discrete,
VectorActionSize = new[] { 4, 5, 6 }
};
var masker = new DiscreteActionMasker(bp);
Assert.Catch<UnityAgentsException>(
() => masker.SetMask(0, new[] {5}));
Assert.Catch<UnityAgentsException>(
() => masker.SetMask(1, new[] {5}));
masker.SetMask(2, new[] {5});
Assert.Catch<UnityAgentsException>(
() => masker.SetMask(3, new[] {1}));
masker.GetMask();
masker.ResetMask();
masker.SetMask(0, new[] {0, 1, 2, 3});
Assert.Catch<UnityAgentsException>(
() => masker.GetMask());
}
[Test]
public void MultipleMaskEdit()
{
var bp = new BrainParameters();
bp.VectorActionSpaceType = SpaceType.Discrete;
bp.VectorActionSize = new[] {4, 5, 6};
var masker = new DiscreteActionMasker(bp);
masker.SetMask(0, new[] {0, 1});
masker.SetMask(0, new[] {3});
masker.SetMask(2, new[] {1});
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)
{
if ((i == 0) || (i == 1) || (i == 3) || (i == 10))
{
Assert.IsTrue(mask[i]);
}
else
{
Assert.IsFalse(mask[i]);
}
}
}
}
}