浏览代码

[Renaming] SetActionMask -> SetDiscreteActionMask + added the virtual method CollectDiscreteActionMasks (#3525)

* Code edits

* Modified the markdowns

* Update com.unity.ml-agents/CHANGELOG.md

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Update docs/Learning-Environment-Design-Agents.md

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Update docs/Learning-Environment-Design-Agents.md

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Renaming files and methods

* Addressing comments

* Update docs/Learning-Environment-Design-Agents.md

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

Co-authored-by: Chris Elion <celion@gmail.com>
/asymm-envs
GitHub 5 年前
当前提交
9a371b17
共有 9 个文件被更改,包括 214 次插入286 次删除
  1. 53
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
  2. 3
      com.unity.ml-agents/CHANGELOG.md
  3. 59
      com.unity.ml-agents/Runtime/Agent.cs
  4. 40
      com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs
  5. 14
      docs/Learning-Environment-Design-Agents.md
  6. 7
      docs/Migrating.md
  7. 144
      com.unity.ml-agents/Runtime/DiscreteActionMasker.cs
  8. 180
      com.unity.ml-agents/Runtime/ActionMasker.cs
  9. 0
      /com.unity.ml-agents/Runtime/DiscreteActionMasker.cs.meta

53
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


{
}
public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
// There are no numeric observations to collect as this environment uses visual
// observations.
SetMask(actionMasker);
}
}
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;
var positionZ = (int)transform.position.z;
var maxPosition = (int)Academy.Instance.FloatProperties.GetPropertyWithDefault("gridSize", 5f) - 1;
/// <summary>
/// Applies the mask for the agent's actions to disallow unnecessary actions.
/// </summary>
void SetMask(ActionMasker actionMasker)
{
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;
var positionZ = (int)transform.position.z;
var maxPosition = (int)Academy.Instance.FloatProperties.GetPropertyWithDefault("gridSize", 5f) - 1;
if (positionX == 0)
{
actionMasker.SetMask(0, new int[]{ k_Left});
}
if (positionX == 0)
{
actionMasker.SetActionMask(k_Left);
}
if (positionX == maxPosition)
{
actionMasker.SetMask(0, new int[]{k_Right});
}
if (positionX == maxPosition)
{
actionMasker.SetActionMask(k_Right);
}
if (positionZ == 0)
{
actionMasker.SetActionMask(k_Down);
}
if (positionZ == 0)
{
actionMasker.SetMask(0, new int[]{k_Down});
}
if (positionZ == maxPosition)
{
actionMasker.SetActionMask(k_Up);
if (positionZ == maxPosition)
{
actionMasker.SetMask(0, new int[]{k_Up});
}
}
}

3
com.unity.ml-agents/CHANGELOG.md


## [Unreleased]
### Major Changes
- Agent.CollectObservations now takes a VectorSensor argument. It was also overloaded to optionally take an ActionMasker argument. (#3352, #3389)
- `Agent.CollectObservations` now takes a VectorSensor argument. (#3352, #3389)
- Added `Agent.CollectDiscreteActionMasks` virtual method with a `DiscreteActionMasker` argument to specify which discrete actions are unavailable to the Agent. (#3525)
- Beta support for ONNX export was added. If the `tf2onnx` python package is installed, models will be saved to `.onnx` as well as `.nn` format.
Note that Barracuda 0.6.0 or later is required to import the `.onnx` files properly
- Multi-GPU training and the `--multi-gpu` option has been removed temporarily. (#3345)

59
com.unity.ml-agents/Runtime/Agent.cs


/// an Agent. An agent produces observations and takes actions in the
/// environment. Observations are determined by the cameras attached
/// to the agent in addition to the vector observations implemented by the
/// user in <see cref="Agent.CollectObservations(VectorSensor)"/> or
/// <see cref="Agent.CollectObservations(VectorSensor, ActionMasker)"/>.
/// user in <see cref="Agent.CollectObservations(VectorSensor)"/>.
/// On the other hand, actions are determined by decisions produced by a Policy.
/// Currently, this class is expected to be extended to implement the desired agent behavior.
/// </summary>

bool m_Initialized;
/// Keeps track of the actions that are masked at each step.
ActionMasker m_ActionMasker;
DiscreteActionMasker m_ActionMasker;
/// <summary>
/// Set of DemonstrationWriters that the Agent will write its step information to.

void ResetData()
{
var param = m_PolicyFactory.brainParameters;
m_ActionMasker = new ActionMasker(param);
m_ActionMasker = new DiscreteActionMasker(param);
// If we haven't initialized vectorActions, initialize to 0. This should only
// happen during the creation of the Agent. In subsequent episodes, vectorAction
// should stay the previous action before the Done(), so that it is properly recorded.

UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
{
CollectObservations(collectObservationsSensor, m_ActionMasker);
CollectObservations(collectObservationsSensor);
}
using (TimerStack.Instance.Scoped("CollectDiscreteActionMasks"))
{
if (m_PolicyFactory.brainParameters.vectorActionSpaceType == SpaceType.Discrete)
{
CollectDiscreteActionMasks(m_ActionMasker);
}
}
m_Info.actionMasks = m_ActionMasker.GetMask();

}
/// <summary>
/// Collects the vector observations of the agent alongside the masked actions.
/// The agent observation describes the current environment from the
/// perspective of the agent.
/// Collects the masks for discrete actions.
/// When using discrete actions, the agent will not perform the masked action.
/// <param name="sensor">
/// The vector observations for the agent.
/// </param>
/// The masked actions for the agent.
/// The action masker for the agent.
/// An agents observation is any environment information that helps
/// the Agent achieve its goal. For example, for a fighting Agent, its
/// observation could include distances to friends or enemies, or the
/// current level of ammunition at its disposal.
/// Recall that an Agent may attach vector or visual observations.
/// Vector observations are added by calling the provided helper methods
/// on the VectorSensor input:
/// - <see cref="VectorSensor.AddObservation(int)"/>
/// - <see cref="VectorSensor.AddObservation(float)"/>
/// - <see cref="VectorSensor.AddObservation(Vector3)"/>
/// - <see cref="VectorSensor.AddObservation(Vector2)"/>
/// - <see cref="VectorSensor.AddObservation(Quaternion)"/>
/// - <see cref="VectorSensor.AddObservation(bool)"/>
/// - <see cref="VectorSensor.AddObservation(IEnumerable{float})"/>
/// - <see cref="VectorSensor.AddOneHotObservation(int, int)"/>
/// Depending on your environment, any combination of these helpers can
/// be used. They just need to be used in the exact same order each time
/// this method is called and the resulting size of the vector observation
/// needs to match the vectorObservationSize attribute of the linked Brain.
/// Visual observations are implicitly added from the cameras attached to
/// the Agent.
/// action by masking it. You can call the following method on the ActionMasker
/// input :
/// - <see cref="ActionMasker.SetActionMask(int)"/>
/// - <see cref="ActionMasker.SetActionMask(int, int)"/>
/// - <see cref="ActionMasker.SetActionMask(int, IEnumerable{int})"/>
/// - <see cref="ActionMasker.SetActionMask(IEnumerable{int})"/>
/// The branch input is the index of the action, actionIndices are the indices of the
/// invalid options for that action.
/// action by masking it with <see cref="DiscreteActionMasker.SetMask(int, IEnumerable{int})"/>
public virtual void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
public virtual void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
CollectObservations(sensor);
}
/// <summary>

40
com.unity.ml-agents/Tests/Editor/EditModeTestActionMasker.cs


public void Contruction()
{
var bp = new BrainParameters();
var masker = new ActionMasker(bp);
var masker = new DiscreteActionMasker(bp);
Assert.IsNotNull(masker);
}

var bp = new BrainParameters();
bp.vectorActionSpaceType = SpaceType.Continuous;
bp.vectorActionSize = new[] {4};
var masker = new ActionMasker(bp);
masker.SetActionMask(0, new[] {0});
var masker = new DiscreteActionMasker(bp);
masker.SetMask(0, new[] {0});
Assert.Catch<UnityAgentsException>(() => masker.GetMask());
}

var bp = new BrainParameters();
bp.vectorActionSpaceType = SpaceType.Discrete;
var masker = new ActionMasker(bp);
var masker = new DiscreteActionMasker(bp);
var mask = masker.GetMask();
Assert.IsNull(mask);
}

var bp = new BrainParameters();
bp.vectorActionSpaceType = SpaceType.Discrete;
bp.vectorActionSize = new[] {4, 5, 6};
var masker = new ActionMasker(bp);
var masker = new DiscreteActionMasker(bp);
masker.SetActionMask(0, new[] {1, 2, 3});
masker.SetMask(0, new[] {1, 2, 3});
mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsTrue(mask[1]);

vectorActionSpaceType = SpaceType.Discrete,
vectorActionSize = new[] { 4, 5, 6 }
};
var masker = new ActionMasker(bp);
masker.SetActionMask(1, new[] {1, 2, 3});
var masker = new DiscreteActionMasker(bp);
masker.SetMask(1, new[] {1, 2, 3});
var mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsFalse(mask[4]);

vectorActionSpaceType = SpaceType.Discrete,
vectorActionSize = new[] { 4, 5, 6 }
};
var masker = new ActionMasker(bp);
masker.SetActionMask(1, new[] {1, 2, 3});
var masker = new DiscreteActionMasker(bp);
masker.SetMask(1, new[] {1, 2, 3});
masker.ResetMask();
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)

vectorActionSpaceType = SpaceType.Discrete,
vectorActionSize = new[] { 4, 5, 6 }
};
var masker = new ActionMasker(bp);
var masker = new DiscreteActionMasker(bp);
() => masker.SetActionMask(0, new[] {5}));
() => masker.SetMask(0, new[] {5}));
() => masker.SetActionMask(1, new[] {5}));
masker.SetActionMask(2, new[] {5});
() => masker.SetMask(1, new[] {5}));
masker.SetMask(2, new[] {5});
() => masker.SetActionMask(3, new[] {1}));
() => masker.SetMask(3, new[] {1}));
masker.SetActionMask(0, new[] {0, 1, 2, 3});
masker.SetMask(0, new[] {0, 1, 2, 3});
Assert.Catch<UnityAgentsException>(
() => masker.GetMask());
}

var bp = new BrainParameters();
bp.vectorActionSpaceType = SpaceType.Discrete;
bp.vectorActionSize = new[] {4, 5, 6};
var masker = new ActionMasker(bp);
masker.SetActionMask(0, new[] {0, 1});
masker.SetActionMask(0, new[] {3});
masker.SetActionMask(2, new[] {1});
var masker = new DiscreteActionMasker(bp);
masker.SetMask(0, new[] {0, 1});
masker.SetMask(0, new[] {3});
masker.SetMask(2, new[] {1});
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)
{

14
docs/Learning-Environment-Design-Agents.md


neural network, the Agent will be unable to perform the specified action. Note
that when the Agent is controlled by its Heuristic, the Agent will
still be able to decide to perform the masked action. In order to mask an
action, call the method `SetActionMask` on the optional `ActionMasker` argument of the `CollectObservation` method :
action, override the `Agent.CollectDiscreteActionMasks()` virtual method, and call `DiscreteActionMasker.SetMask()` in it:
public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker){
actionMasker.SetActionMask(branch, actionIndices)
public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker){
actionMasker.SetMask(branch, actionIndices)
}
```

the action
* `actionIndices` is a list of `int` or a single `int` corresponding to the
index of the action that the Agent cannot perform.
* `actionIndices` is a list of `int` corresponding to the
indices of the actions that the Agent cannot perform.
For example, if you have an Agent with 2 branches and on the first branch
(branch 0) there are 4 possible actions : _"do nothing"_, _"jump"_, _"shoot"_

```csharp
SetActionMask(0, new int[2]{1,2})
SetMask(0, new int[2]{1,2})
* You can call `SetActionMask` multiple times if you want to put masks on
* You can call `SetMask` multiple times if you want to put masks on
multiple branches.
* You cannot mask all the actions of a branch.
* You cannot mask actions in continuous control.

7
docs/Migrating.md


### Important changes
* The `Agent.CollectObservations()` virtual method now takes as input a `VectorSensor` sensor as argument. The `Agent.AddVectorObs()` methods were removed.
* The `SetActionMask` method must now be called on the optional `ActionMasker` argument of the `CollectObservations` method. (We now consider an action mask as a type of observation)
* The `SetActionMask` method was renamed to `SetMask` and must now be called on the `DiscreteActionMasker` argument of the `CollectDiscreteActionMasks` virtual method.
* We consolidated our API for `DiscreteActionMasker`. `SetMask` takes two arguments : the branch index and the list of masked actions for that branch.
* The `SetActionMask` method must now be called on the optional `ActionMasker` argument of the `CollectObservations` method. (We now consider an action mask as a type of observation)
* The `SetMask` method must now be called on the `DiscreteActionMasker` argument of the `CollectDiscreteActionMasks` method.
* The method `GetStepCount()` on the Agent class has been replaced with the property getter `StepCount`
* The `--multi-gpu` option has been removed temporarily.

* Replace your calls to `SetActionMask` on your Agent to `ActionMasker.SetActionMask` in `CollectObservations`
* Replace your calls to `SetActionMask` on your Agent with `DiscreteActionMasker.SetMask` in `CollectDiscreteActionMasks`.
* If you call `RayPerceptionSensor.PerceiveStatic()` manually, add your inputs to a `RayPerceptionInput`. To get the previous float array output,
iterate through `RayPerceptionOutput.rayOutputs` and call `RayPerceptionOutput.RayOutput.ToFloatArray()`.
* Re-import all of your `*.NN` files to work with the updated Barracuda package.

144
com.unity.ml-agents/Runtime/DiscreteActionMasker.cs


using System;
using System.Collections.Generic;
using System.Linq;
namespace MLAgents
{
/// <summary>
/// Agents that take discrete actions can explicitly indicate that specific actions
/// are not allowed at a point in time. This enables the agent to indicate that some actions
/// may be illegal (e.g. the King in Chess taking a move to the left if it is already in the
/// left side of the board). This class represents the set of masked actions and provides
/// the utilities for setting and retrieving them.
/// </summary>
public class DiscreteActionMasker
{
    /// When using discrete control, is the starting indices of the actions
    /// when all the branches are concatenated with each other.
    int[] m_StartingActionIndices;

    /// Flat mask over the concatenated branches; true marks a disallowed action.
    bool[] m_CurrentMask;

    readonly BrainParameters m_BrainParameters;

    /// <summary>
    /// Creates a masker for the action space described by the given brain parameters.
    /// </summary>
    /// <param name="brainParameters">Describes the branches of the discrete action space.</param>
    internal DiscreteActionMasker(BrainParameters brainParameters)
    {
        m_BrainParameters = brainParameters;
    }

    /// <summary>
    /// Modifies an action mask for discrete control agents. When used, the agent will not be
    /// able to perform the actions passed as argument at the next decision for the specified
    /// action branch. The actionIndices correspond to the action options the agent will
    /// be unable to perform.
    /// </summary>
    /// <param name="branch">The branch for which the actions will be masked</param>
    /// <param name="actionIndices">The indices of the masked actions</param>
    /// <exception cref="UnityAgentsException">
    /// Thrown when the branch or one of the action indices is out of range.
    /// </exception>
    public void SetMask(int branch, IEnumerable<int> actionIndices)
    {
        // If the branch does not exist, raise an error. The lower-bound check matters:
        // a negative branch previously escaped validation and surfaced later as an
        // IndexOutOfRangeException instead of the class's UnityAgentsException.
        if (branch < 0 || branch >= m_BrainParameters.vectorActionSize.Length)
            throw new UnityAgentsException(
                "Invalid Action Masking : Branch " + branch + " does not exist.");

        var totalNumberActions = m_BrainParameters.vectorActionSize.Sum();

        // By default, the masks are null. If we want to specify a new mask, we initialize
        // the actionMasks with trues.
        if (m_CurrentMask == null)
        {
            m_CurrentMask = new bool[totalNumberActions];
        }

        // If this is the first time the masked actions are used, we generate the starting
        // indices for each branch.
        if (m_StartingActionIndices == null)
        {
            m_StartingActionIndices = Utilities.CumSum(m_BrainParameters.vectorActionSize);
        }

        // Perform the masking
        foreach (var actionIndex in actionIndices)
        {
            // Reject negative indices explicitly. Without this guard, a negative index
            // on a branch > 0 would silently set a bit in the preceding branch's range,
            // corrupting the mask rather than failing.
            if (actionIndex < 0)
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking: Action index cannot be negative.");
            }
            if (actionIndex >= m_BrainParameters.vectorActionSize[branch])
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking: Action Mask is too large for specified branch.");
            }
            m_CurrentMask[actionIndex + m_StartingActionIndices[branch]] = true;
        }
    }

    /// <summary>
    /// Get the current mask for an agent.
    /// </summary>
    /// <returns>A mask for the agent. A boolean array of length equal to the total number of
    /// actions, or null if no mask has been set.</returns>
    internal bool[] GetMask()
    {
        if (m_CurrentMask != null)
        {
            AssertMask();
        }
        return m_CurrentMask;
    }

    /// <summary>
    /// Makes sure that the current mask is usable.
    /// </summary>
    /// <exception cref="UnityAgentsException">
    /// Thrown when the action space is not discrete, or every action of some branch is masked.
    /// </exception>
    void AssertMask()
    {
        // Action Masks can only be used in Discrete Control.
        if (m_BrainParameters.vectorActionSpaceType != SpaceType.Discrete)
        {
            throw new UnityAgentsException(
                "Invalid Action Masking : Can only set action mask for Discrete Control.");
        }

        // A branch with every option masked would leave the agent with no legal action.
        var numBranches = m_BrainParameters.vectorActionSize.Length;
        for (var branchIndex = 0; branchIndex < numBranches; branchIndex++)
        {
            if (AreAllActionsMasked(branchIndex))
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking : All the actions of branch " + branchIndex +
                    " are masked.");
            }
        }
    }

    /// <summary>
    /// Resets the current mask for an agent.
    /// </summary>
    internal void ResetMask()
    {
        if (m_CurrentMask != null)
        {
            Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
        }
    }

    /// <summary>
    /// Checks if all the actions in the input branch are masked.
    /// </summary>
    /// <param name="branch"> The index of the branch to check</param>
    /// <returns> True if all the actions of the branch are masked</returns>
    bool AreAllActionsMasked(int branch)
    {
        if (m_CurrentMask == null)
        {
            return false;
        }
        var start = m_StartingActionIndices[branch];
        var end = m_StartingActionIndices[branch + 1];
        for (var i = start; i < end; i++)
        {
            if (!m_CurrentMask[i])
            {
                return false;
            }
        }
        return true;
    }
}
}

180
com.unity.ml-agents/Runtime/ActionMasker.cs


using System;
using System.Collections.Generic;
using System.Linq;
namespace MLAgents
{
/// <summary>
/// Agents that take discrete actions can explicitly indicate that specific actions
/// are not allowed at a point in time. This enables the agent to indicate that some actions
/// may be illegal (e.g. the King in Chess taking a move to the left if it is already in the
/// left side of the board). This class represents the set of masked actions and provides
/// the utilities for setting and retrieving them.
/// </summary>
public class ActionMasker
{
    /// When using discrete control, is the starting indices of the actions
    /// when all the branches are concatenated with each other.
    int[] m_StartingActionIndices;

    /// Flat mask over all concatenated branches; true marks a masked action.
    bool[] m_CurrentMask;

    readonly BrainParameters m_BrainParameters;

    /// <summary>
    /// Builds a masker for the action space described by the given brain parameters.
    /// </summary>
    /// <param name="brainParameters">Describes the branches of the discrete action space.</param>
    internal ActionMasker(BrainParameters brainParameters)
    {
        m_BrainParameters = brainParameters;
    }

    /// <summary>
    /// Sets an action mask for discrete control agents. When used, the agent will not be
    /// able to perform the actions passed as argument at the next decision.
    /// The actionIndices correspond to the actions the agent will be unable to perform
    /// on the branch 0.
    /// </summary>
    /// <param name="actionIndices">The indices of the masked actions on branch 0.</param>
    public void SetActionMask(IEnumerable<int> actionIndices) => SetActionMask(0, actionIndices);

    /// <summary>
    /// Sets an action mask for discrete control agents. When used, the agent will not be
    /// able to perform the action passed as argument at the next decision for the specified
    /// action branch. The actionIndex correspond to the action the agent will be unable
    /// to perform.
    /// </summary>
    /// <param name="branch">The branch for which the actions will be masked.</param>
    /// <param name="actionIndex">The index of the masked action.</param>
    public void SetActionMask(int branch, int actionIndex) => SetActionMask(branch, new[] { actionIndex });

    /// <summary>
    /// Sets an action mask for discrete control agents. When used, the agent will not be
    /// able to perform the action passed as argument at the next decision. The actionIndex
    /// correspond to the action the agent will be unable to perform on the branch 0.
    /// </summary>
    /// <param name="actionIndex">The index of the masked action on branch 0</param>
    public void SetActionMask(int actionIndex) => SetActionMask(0, new[] { actionIndex });

    /// <summary>
    /// Modifies an action mask for discrete control agents. When used, the agent will not be
    /// able to perform the actions passed as argument at the next decision for the specified
    /// action branch. The actionIndices correspond to the action options the agent will
    /// be unable to perform.
    /// </summary>
    /// <param name="branch">The branch for which the actions will be masked</param>
    /// <param name="actionIndices">The indices of the masked actions</param>
    public void SetActionMask(int branch, IEnumerable<int> actionIndices)
    {
        // A branch index past the end of the declared branches is a user error.
        if (branch >= m_BrainParameters.vectorActionSize.Length)
        {
            throw new UnityAgentsException(
                "Invalid Action Masking : Branch " + branch + " does not exist.");
        }

        // Lazily build the per-branch starting offsets and the flat mask on first use.
        if (m_StartingActionIndices == null)
        {
            m_StartingActionIndices = Utilities.CumSum(m_BrainParameters.vectorActionSize);
        }
        if (m_CurrentMask == null)
        {
            m_CurrentMask = new bool[m_BrainParameters.vectorActionSize.Sum()];
        }

        // Flag each requested option within the branch's slice of the flat mask.
        var branchOffset = m_StartingActionIndices[branch];
        var branchSize = m_BrainParameters.vectorActionSize[branch];
        foreach (var invalidAction in actionIndices)
        {
            if (invalidAction >= branchSize)
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking: Action Mask is too large for specified branch.");
            }
            m_CurrentMask[invalidAction + branchOffset] = true;
        }
    }

    /// <summary>
    /// Get the current mask for an agent.
    /// </summary>
    /// <returns>A mask for the agent. A boolean array of length equal to the total number of
    /// actions, or null when nothing has been masked.</returns>
    internal bool[] GetMask()
    {
        if (m_CurrentMask == null)
        {
            return null;
        }
        AssertMask();
        return m_CurrentMask;
    }

    /// <summary>
    /// Makes sure that the current mask is usable.
    /// </summary>
    void AssertMask()
    {
        // Action Masks can only be used in Discrete Control.
        if (m_BrainParameters.vectorActionSpaceType != SpaceType.Discrete)
        {
            throw new UnityAgentsException(
                "Invalid Action Masking : Can only set action mask for Discrete Control.");
        }

        // Masking every option of a branch leaves the agent with no legal choice there.
        for (var branchIndex = 0; branchIndex < m_BrainParameters.vectorActionSize.Length; branchIndex++)
        {
            if (AreAllActionsMasked(branchIndex))
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking : All the actions of branch " + branchIndex +
                    " are masked.");
            }
        }
    }

    /// <summary>
    /// Resets the current mask for an agent.
    /// </summary>
    internal void ResetMask()
    {
        if (m_CurrentMask == null)
        {
            return;
        }
        Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
    }

    /// <summary>
    /// Checks if all the actions in the input branch are masked.
    /// </summary>
    /// <param name="branch"> The index of the branch to check</param>
    /// <returns> True if all the actions of the branch are masked</returns>
    bool AreAllActionsMasked(int branch)
    {
        if (m_CurrentMask == null)
        {
            return false;
        }
        var allMasked = true;
        for (var offset = m_StartingActionIndices[branch]; offset < m_StartingActionIndices[branch + 1]; offset++)
        {
            allMasked = allMasked && m_CurrentMask[offset];
        }
        return allMasked;
    }
}
}

/com.unity.ml-agents/Runtime/ActionMasker.cs.meta → /com.unity.ml-agents/Runtime/DiscreteActionMasker.cs.meta

正在加载...
取消
保存