using System;
using System.Collections.Generic;
using System.Linq;
namespace MLAgents
{
/// <summary>
/// Accumulates and validates the action mask of an agent. The mask is a flat boolean
/// array covering every action of every branch concatenated together; an entry set to
/// true means the corresponding action is masked.
/// </summary>
public class ActionMasker
{
    /// When using discrete control, holds the starting index of each branch's actions
    /// when all the branches are concatenated with each other. It has one more entry
    /// than there are branches: the last entry is the total number of actions.
    private int[] _startingActionIndices;

    /// One flag per concatenated action (true = masked). Stays null until
    /// SetActionMask is called for the first time.
    private bool[] _currentMask;

    private readonly BrainParameters _brainParameters;

    /// <summary>
    /// Creates an ActionMasker that reads branch sizes and the action space type from
    /// the given brain parameters.
    /// </summary>
    /// <param name="brainParameters">The brain parameters describing the action space
    /// </param>
    public ActionMasker(BrainParameters brainParameters)
    {
        this._brainParameters = brainParameters;
    }

    /// <summary>
    /// Modifies an action mask for discrete control agents. When used, the agent will not be
    /// able to perform the actions passed as argument at the next decision. The
    /// actionIndices correspond to the actions the agent will be unable to perform.
    /// Masks are cumulative until <see cref="ResetMask"/> is called.
    /// </summary>
    /// <param name="branch">The branch for which the actions will be masked</param>
    /// <param name="actionIndices">The indices of the masked actions, relative to the
    /// start of the branch</param>
    /// <exception cref="UnityAgentsException">Thrown when the branch does not exist, when
    /// actionIndices is null, or when an action index is outside of the branch.
    /// Indices enumerated before an invalid one have already been masked.</exception>
    public void SetActionMask(int branch, IEnumerable<int> actionIndices)
    {
        var branchSizes = _brainParameters.vectorActionSize;

        // If the branch does not exist, raise an error. Negative branches are rejected
        // here too instead of surfacing later as an IndexOutOfRangeException.
        if (branch < 0 || branch >= branchSizes.Length)
        {
            throw new UnityAgentsException(
                "Invalid Action Masking : Branch " + branch + " does not exist.");
        }

        if (actionIndices == null)
        {
            throw new UnityAgentsException(
                "Invalid Action Masking : The action indices cannot be null.");
        }

        // By default, the mask is null. When a mask is specified for the first time, we
        // allocate it; a new bool array is all false, i.e. no action is masked yet.
        if (_currentMask == null)
        {
            _currentMask = new bool[branchSizes.Sum()];
        }

        // If this is the first time the masked actions are used, we generate the starting
        // indices for each branch.
        if (_startingActionIndices == null)
        {
            _startingActionIndices = CreateActionStartinIndices();
        }

        var branchSize = branchSizes[branch];
        var branchStart = _startingActionIndices[branch];

        // Perform the masking
        foreach (var actionIndex in actionIndices)
        {
            // A negative index would otherwise address (and silently mask) an action
            // belonging to a previous branch.
            if (actionIndex < 0)
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking: Action Mask cannot be negative.");
            }
            if (actionIndex >= branchSize)
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking: Action Mask is too large for specified branch.");
            }
            _currentMask[actionIndex + branchStart] = true;
        }
    }

    /// <summary>
    /// Get the current mask for an agent. The mask is validated before being returned.
    /// </summary>
    /// <returns>A mask for the agent: a boolean array of length equal to the total number
    /// of actions, or null if no mask was ever set. The returned array is the internal
    /// one, not a copy.</returns>
    /// <exception cref="UnityAgentsException">Thrown when a mask exists but the action
    /// space is not discrete, or when every action of some branch is masked.</exception>
    public bool[] GetMask()
    {
        if (_currentMask != null)
        {
            AssertMask();
        }
        return _currentMask;
    }

    /// <summary>
    /// Makes sure that the current mask is usable.
    /// </summary>
    /// <exception cref="UnityAgentsException">Thrown when the action space is not discrete
    /// or when all the actions of a branch are masked.</exception>
    private void AssertMask()
    {
        // Action Masks can only be used in Discrete Control.
        if (_brainParameters.vectorActionSpaceType != SpaceType.discrete)
        {
            throw new UnityAgentsException(
                "Invalid Action Masking : Can only set action mask for Discrete Control.");
        }

        var numBranches = _brainParameters.vectorActionSize.Length;
        for (var branchIndex = 0; branchIndex < numBranches; branchIndex++)
        {
            if (AreAllActionsMasked(branchIndex))
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking : All the actions of branch " + branchIndex +
                    " are masked.");
            }
        }
    }

    /// <summary>
    /// Resets the current mask for an agent: every action becomes unmasked again.
    /// Does nothing if no mask was ever set.
    /// </summary>
    public void ResetMask()
    {
        if (_currentMask != null)
        {
            Array.Clear(_currentMask, 0, _currentMask.Length);
        }
    }

    /// <summary>
    /// Generates an array containing the starting indices of each branch in the vector
    /// action. Makes a cumulative sum.
    /// </summary>
    /// <returns>An array with one more entry than there are branches: entry i is the index
    /// of the first action of branch i, and the last entry is the total number of actions.
    /// </returns>
    private int[] CreateActionStartinIndices()
    {
        var vectorActionSize = _brainParameters.vectorActionSize;
        var runningSum = 0;
        // result[0] is left at 0: the first branch starts at the beginning of the mask.
        var result = new int[vectorActionSize.Length + 1];
        for (var actionIndex = 0;
            actionIndex < vectorActionSize.Length; actionIndex++)
        {
            runningSum += vectorActionSize[actionIndex];
            result[actionIndex + 1] = runningSum;
        }
        return result;
    }

    /// <summary>
    /// Checks if all the actions in the input branch are masked
    /// </summary>
    /// <param name="branch">The index of the branch to check</param>
    /// <returns>True if all the actions of the branch are masked; false if at least one
    /// is unmasked or if no mask was ever set</returns>
    private bool AreAllActionsMasked(int branch)
    {
        if (_currentMask == null)
        {
            return false;
        }

        // The extra trailing entry of _startingActionIndices makes branch + 1 valid
        // for the last branch.
        var start = _startingActionIndices[branch];
        var end = _startingActionIndices[branch + 1];
        for (var i = start; i < end; i++)
        {
            if (!_currentMask[i])
            {
                return false;
            }
        }
        return true;
    }
}
}