浏览代码

non-IEnumerable interface for action masking (#5060)

/v2-staging-rebase
GitHub 4 年前
当前提交
4863475c
共有 12 个文件被更改,包括 149 次插入96 次删除
  1. 8
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
  2. 69
      com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs
  3. 5
      com.unity.ml-agents/CHANGELOG.md
  4. 29
      com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs
  5. 2
      com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
  6. 25
      com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
  7. 2
      com.unity.ml-agents/Runtime/Agent.cs
  8. 45
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs
  9. 6
      com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs
  10. 5
      com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs
  11. 24
      docs/Learning-Environment-Design-Agents.md
  12. 25
      docs/Migrating.md

8
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


if (positionX == 0)
{
actionMask.WriteMask(0, new[] { k_Left });
actionMask.SetActionEnabled(0, k_Left, false);
actionMask.WriteMask(0, new[] { k_Right });
actionMask.SetActionEnabled(0, k_Right, false);
actionMask.WriteMask(0, new[] { k_Down });
actionMask.SetActionEnabled(0, k_Down, false);
actionMask.WriteMask(0, new[] { k_Up });
actionMask.SetActionEnabled(0, k_Up, false);
}
}
}

69
com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs


/// <param name="agent"></param>
/// <param name="name"></param>
public Match3Actuator(AbstractBoard board,
bool forceHeuristic,
int seed,
Agent agent,
string name)
bool forceHeuristic,
int seed,
Agent agent,
string name)
{
m_Board = board;
m_Rows = board.Rows;

/// <inheritdoc/>
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
const int branch = 0;
bool foundValidMove = false;
actionMask.WriteMask(0, InvalidMoveIndices());
}
}
var numMoves = m_Board.NumMoves();
/// <inheritdoc/>
public string Name { get; }
var currentMove = Move.FromMoveIndex(0, m_Board.Rows, m_Board.Columns);
for (var i = 0; i < numMoves; i++)
{
if (m_Board.IsMoveValid(currentMove))
{
foundValidMove = true;
}
else
{
actionMask.SetActionEnabled(branch, i, false);
}
currentMove.Next(m_Board.Rows, m_Board.Columns);
}
/// <inheritdoc/>
public void ResetData()
{
}
/// <inheritdoc/>
public BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.Match3Actuator;
}
IEnumerable<int> InvalidMoveIndices()
{
var numValidMoves = m_Board.NumMoves();
foreach (var move in m_Board.InvalidMoves())
{
numValidMoves--;
if (numValidMoves == 0)
if (!foundValidMove)
{
// If all the moves are invalid and we mask all the actions out, this will cause an assert
// later on in IDiscreteActionMask. Instead, fire a callback to the user if they provided one,

"an invalid move will be passed to AbstractBoard.MakeMove()."
);
}
// This means the last move won't be returned as an invalid index.
yield break;
actionMask.SetActionEnabled(branch, numMoves - 1, true);
yield return move.MoveIndex;
/// <inheritdoc/>
public string Name { get; }
/// <inheritdoc/>
public void ResetData()
{
}
/// <inheritdoc/>
public BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.Match3Actuator;
}
public void Heuristic(in ActionBuffers actionsOut)
{
var discreteActions = actionsOut.DiscreteActions;

var bestMoveIndex = 0;
var bestMovePoints = -1;
var numMovesAtCurrentScore = 0;

5
com.unity.ml-agents/CHANGELOG.md


### Major Changes
#### com.unity.ml-agents (C#)
- Some methods previously marked as `Obsolete` have been removed. If you were using these methods, you need to replace them with their supported counterpart.
- The interface for disabling discrete actions in `IDiscreteActionMask` has changed.
`WriteMask(int branch, IEnumerable<int> actionIndices)` was replaced with
`SetActionEnabled(int branch, int actionIndex, bool isEnabled)`. See the
[Migration Guide](https://github.com/Unity-Technologies/ml-agents/blob/release_14_docs/docs/Migrating.md) for more
details. (#5060)
#### ml-agents / ml-agents-envs / gym-unity (Python)
### Minor Changes

29
com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs


}
/// <inheritdoc/>
public void WriteMask(int branch, IEnumerable<int> actionIndices)
public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
// Perform the masking
foreach (var actionIndex in actionIndices)
#if DEBUG
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
#if DEBUG
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
{
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = true;
}
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = !isEnabled;
}
void LazyInitialize()

}
}
/// <inheritdoc/>
public bool[] GetMask()
/// <summary>
/// Get the current mask for an agent.
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
internal bool[] GetMask()
{
#if DEBUG
if (m_CurrentMask != null)

/// <summary>
/// Resets the current mask for an agent.
/// </summary>
public void ResetMask()
internal void ResetMask()
{
if (m_CurrentMask != null)
{

2
com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs


/// </param>
/// <remarks>
/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it with <see cref="IDiscreteActionMask.WriteMask"/>.
/// action by masking it with <see cref="IDiscreteActionMask.SetActionEnabled"/>.
///
/// See [Agents - Actions] for more information on masking actions.
///

25
com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs


public interface IDiscreteActionMask
{
/// <summary>
/// Modifies an action mask for discrete control agents.
/// Set whether or not the action index for the given branch is allowed.
/// <remarks>
/// When used, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndices correspond
/// By default, all discrete actions are allowed.
/// If isEnabled is false, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndex corresponds
/// to the action options the agent will be unable to perform.
///
/// See [Agents - Actions] for more information on masking actions.

/// <param name="branch">The branch for which the actions will be masked.</param>
/// <param name="actionIndices">The indices of the masked actions.</param>
void WriteMask(int branch, IEnumerable<int> actionIndices);
/// <summary>
/// Get the current mask for an agent.
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
bool[] GetMask();
/// <summary>
/// Resets the current mask for an agent.
/// </summary>
void ResetMask();
/// <param name="actionIndex">Index of the action</param>
/// <param name="isEnabled">Whether the action is allowed or not.</param>
void SetActionEnabled(int branch, int actionIndex, bool isEnabled);
}
}

2
com.unity.ml-agents/Runtime/Agent.cs


/// </param>
/// <remarks>
/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it with <see cref="IDiscreteActionMask.WriteMask(int, IEnumerable{int})"/>.
/// action by masking it with <see cref="IDiscreteActionMask.SetActionEnabled"/>.
///
/// See [Agents - Actions] for more information on masking actions.
///

45
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs


var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
var mask = masker.GetMask();
Assert.IsNull(mask);
masker.WriteMask(0, new[] { 1, 2, 3 });
masker.SetActionEnabled(0, 1, false);
masker.SetActionEnabled(0, 2, false);
masker.SetActionEnabled(0, 3, false);
mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsTrue(mask[1]);

}
[Test]
public void CanOverwriteMask()
{
    // The branch and action whose enabled state is toggled by this test.
    const int branch = 0;
    const int action = 1;

    // A single test actuator with discrete branches of sizes 4, 5 and 6.
    // NOTE(review): 15 and 3 presumably are the total action count and the number of branches - confirm
    // against the ActuatorDiscreteActionMask constructor.
    var branchSizes = new[] { 4, 5, 6 };
    var testActuator = new TestActuator(ActionSpec.MakeDiscrete(branchSizes), "actuator1");
    var actuators = new IActuator[] { testActuator };
    var actionMask = new ActuatorDiscreteActionMask(actuators, 15, 3);

    // Disabling the action must flag its entry in the mask.
    actionMask.SetActionEnabled(branch, action, false);
    var currentMask = actionMask.GetMask();
    Assert.IsTrue(currentMask[action]);

    // Re-enabling the same action must clear the flag again. The array fetched above is
    // inspected a second time; GetMask() is deliberately not called again here.
    actionMask.SetActionEnabled(branch, action, true);
    Assert.IsFalse(currentMask[action]);
}
[Test]
masker.WriteMask(1, new[] { 1, 2, 3 });
masker.SetActionEnabled(1, 1, false);
masker.SetActionEnabled(1, 2, false);
masker.SetActionEnabled(1, 3, false);
var mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsFalse(mask[4]);

{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
masker.WriteMask(1, new[] { 1, 2, 3 });
masker.SetActionEnabled(1, 1, false);
masker.SetActionEnabled(1, 2, false);
masker.SetActionEnabled(1, 3, false);
masker.ResetMask();
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)

var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
Assert.Catch<UnityAgentsException>(
() => masker.WriteMask(0, new[] { 5 }));
() => masker.SetActionEnabled(0, 5, false));
() => masker.WriteMask(1, new[] { 5 }));
masker.WriteMask(2, new[] { 5 });
() => masker.SetActionEnabled(1, 5, false));
masker.SetActionEnabled(2, 5, false);
() => masker.WriteMask(3, new[] { 1 }));
() => masker.SetActionEnabled(3, 1, false));
masker.WriteMask(0, new[] { 0, 1, 2, 3 });
masker.SetActionEnabled(0, 0, false);
masker.SetActionEnabled(0, 1, false);
masker.SetActionEnabled(0, 2, false);
masker.SetActionEnabled(0, 3, false);
Assert.Catch<UnityAgentsException>(
() => masker.GetMask());
}

{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 4, 5, 6 }), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] { actuator1 }, 15, 3);
masker.WriteMask(0, new[] { 0, 1 });
masker.WriteMask(0, new[] { 3 });
masker.WriteMask(2, new[] { 1 });
masker.SetActionEnabled(0, 0, false);
masker.SetActionEnabled(0, 1, false);
masker.SetActionEnabled(0, 3, false);
masker.SetActionEnabled(2, 1, false);
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)
{

6
com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs


public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
actionMask.WriteMask(i, Masks[i]);
foreach (var actionIndex in Masks[i])
{
actionMask.SetActionEnabled(i, actionIndex, false);
}
}
}

5
com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs


public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
actionMask.WriteMask(Branch, Mask);
foreach (var actionIndex in Mask)
{
actionMask.SetActionEnabled(Branch, actionIndex, false);
}
}
public void Heuristic(in ActionBuffers actionBuffersOut)

24
docs/Learning-Environment-Design-Agents.md


impossible for the next decision. When the Agent is controlled by a neural
network, the Agent will be unable to perform the specified action. Note that
when the Agent is controlled by its Heuristic, the Agent will still be able to
decide to perform the masked action. In order to mask an action, override the
`Agent.WriteDiscreteActionMask()` virtual method, and call
`WriteMask()` on the provided `IDiscreteActionMask`:
decide to perform the masked action. In order to disallow an action, override
the `Agent.WriteDiscreteActionMask()` virtual method, and call
`SetActionEnabled()` on the provided `IDiscreteActionMask`:
actionMask.WriteMask(branch, actionIndices);
actionMask.SetActionEnabled(branch, actionIndex, isEnabled);
- `branch` is the index (starting at 0) of the branch on which you want to mask
the action
- `actionIndices` is a list of `int` corresponding to the indices of the actions
that the Agent **cannot** perform.
- `branch` is the index (starting at 0) of the branch on which you want to
allow or disallow the action
- `actionIndex` is the index of the action that you want to allow or disallow.
- `isEnabled` is a bool indicating whether the action should be allowed or not.
nothing"_ or _"change weapon"_ for his next decision (since action index 1 and 2
nothing"_ or _"change weapon"_ for their next decision (since action index 1 and 2
WriteMask(0, new int[2]{1,2});
actionMask.SetActionEnabled(0, 1, false);
actionMask.SetActionEnabled(0, 2, false);
- You can call `WriteMask` multiple times if you want to put masks on multiple
- You can call `SetActionEnabled` multiple times if you want to put masks on multiple
- At each step, the state of an action is reset and enabled by default.
- You cannot mask all the actions of a branch.
- You cannot mask actions in continuous control.

25
docs/Migrating.md


# Migrating
## Migrating the package to version 2.0
- If you used any of the APIs that were deprecated before version 2.0, you need to use their replacement. These deprecated APIs have been removed. See the migration steps below for specific API replacements.
### IDiscreteActionMask changes
- The interface for disabling specific discrete actions has changed. `IDiscreteActionMask.WriteMask()` was removed,
and replaced with `SetActionEnabled()`. Instead of passing an IEnumerable of indices to disable, you can
now call `SetActionEnabled` for each index to disable (or enable). As an example, if you overrode
`Agent.WriteDiscreteActionMask()` with something that looked like:
```csharp
public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
var branch = 2;
var actionsToDisable = new[] {1, 3};
actionMask.WriteMask(branch, actionsToDisable);
}
```
the equivalent code would now be
```csharp
public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
var branch = 2;
actionMask.SetActionEnabled(branch, 1, false);
actionMask.SetActionEnabled(branch, 3, false);
}
```
## Migrating to Release 13
### Implementing IHeuristic in your IActuator implementations

正在加载...
取消
保存