using System;
using System.Collections.Generic;
using System.Linq;
using MLAgents.Policies;
namespace MLAgents
{
///
/// The DiscreteActionMasker class represents a set of masked (disallowed) actions and
/// provides utilities for setting and retrieving them.
///
///
/// Agents that take discrete actions can explicitly indicate that specific actions
/// are not allowed at a point in time. This enables the agent to indicate that some actions
/// may be illegal. For example, if an agent is adjacent to a wall or other obstacle
/// you could mask any actions that direct the agent to move into the blocked space.
///
public class DiscreteActionMasker
{
/// When using discrete control, is the starting indices of the actions
/// when all the branches are concatenated with each other.
int[] m_StartingActionIndices;
bool[] m_CurrentMask;
readonly BrainParameters m_BrainParameters;
internal DiscreteActionMasker(BrainParameters brainParameters)
{
m_BrainParameters = brainParameters;
}
///
/// Modifies an action mask for discrete control agents.
///
///
/// When used, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndices correspond
/// to the action options the agent will be unable to perform.
///
/// See [Agents - Actions] for more information on masking actions.
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/0.15.1/docs/Learning-Environment-Design-Agents.md#actions
///
/// The branch for which the actions will be masked.
/// The indices of the masked actions.
public void SetMask(int branch, IEnumerable actionIndices)
{
// If the branch does not exist, raise an error
if (branch >= m_BrainParameters.VectorActionSize.Length)
throw new UnityAgentsException(
"Invalid Action Masking : Branch " + branch + " does not exist.");
var totalNumberActions = m_BrainParameters.VectorActionSize.Sum();
// By default, the masks are null. If we want to specify a new mask, we initialize
// the actionMasks with trues.
if (m_CurrentMask == null)
{
m_CurrentMask = new bool[totalNumberActions];
}
// If this is the first time the masked actions are used, we generate the starting
// indices for each branch.
if (m_StartingActionIndices == null)
{
m_StartingActionIndices = Utilities.CumSum(m_BrainParameters.VectorActionSize);
}
// Perform the masking
foreach (var actionIndex in actionIndices)
{
if (actionIndex >= m_BrainParameters.VectorActionSize[branch])
{
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
m_CurrentMask[actionIndex + m_StartingActionIndices[branch]] = true;
}
}
///
/// Get the current mask for an agent.
///
/// A mask for the agent. A boolean array of length equal to the total number of
/// actions.
internal bool[] GetMask()
{
if (m_CurrentMask != null)
{
AssertMask();
}
return m_CurrentMask;
}
///
/// Makes sure that the current mask is usable.
///
void AssertMask()
{
// Action Masks can only be used in Discrete Control.
if (m_BrainParameters.VectorActionSpaceType != SpaceType.Discrete)
{
throw new UnityAgentsException(
"Invalid Action Masking : Can only set action mask for Discrete Control.");
}
var numBranches = m_BrainParameters.VectorActionSize.Length;
for (var branchIndex = 0; branchIndex < numBranches; branchIndex++)
{
if (AreAllActionsMasked(branchIndex))
{
throw new UnityAgentsException(
"Invalid Action Masking : All the actions of branch " + branchIndex +
" are masked.");
}
}
}
///
/// Resets the current mask for an agent.
///
internal void ResetMask()
{
if (m_CurrentMask != null)
{
Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
}
}
///
/// Checks if all the actions in the input branch are masked.
///
/// The index of the branch to check.
/// True if all the actions of the branch are masked.
bool AreAllActionsMasked(int branch)
{
if (m_CurrentMask == null)
{
return false;
}
var start = m_StartingActionIndices[branch];
var end = m_StartingActionIndices[branch + 1];
for (var i = start; i < end; i++)
{
if (!m_CurrentMask[i])
{
return false;
}
}
return true;
}
}
}