浏览代码

Pass action masker as input to CollectObservations (#3389)

* Sentencing Action masking the same as observations
I am rather unsure about the doubling of the CollectObservation methods (and the copy pasta that comes along)
Need to edit the documentation and the migrating doc once we agree we want to do this

* Addressing the comments

* Improvements to the documentation

* Editing the documentation
/asymm-envs
GitHub 5 年前
当前提交
92a8aed2
共有 5 个文件被更改,包括 92 次插入68 次删除
  1. 14
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
  2. 48
      com.unity.ml-agents/Runtime/ActionMasker.cs
  3. 89
      com.unity.ml-agents/Runtime/Agent.cs
  4. 6
      docs/Learning-Environment-Design-Agents.md
  5. 3
      docs/Migrating.md

14
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


{
}
public override void CollectObservations(VectorSensor sensor)
public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
{
// There are no numeric observations to collect as this environment uses visual
// observations.

{
SetMask();
SetMask(actionMasker);
}
}

void SetMask()
void SetMask(ActionMasker actionMasker)
{
// Prevents the agent from picking an action that would make it collide with a wall
var positionX = (int)transform.position.x;

if (positionX == 0)
{
SetActionMask(k_Left);
actionMasker.SetActionMask(k_Left);
SetActionMask(k_Right);
actionMasker.SetActionMask(k_Right);
SetActionMask(k_Down);
actionMasker.SetActionMask(k_Down);
SetActionMask(k_Up);
actionMasker.SetActionMask(k_Up);
}
}

48
com.unity.ml-agents/Runtime/ActionMasker.cs


namespace MLAgents
{
internal class ActionMasker
public class ActionMasker
{
/// When using discrete control, is the starting indices of the actions
/// when all the branches are concatenated with each other.

}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the actions passed as argument at the next decision.
/// The actionIndices correspond to the actions the agent will be unable to perform
/// on the branch 0.
/// </summary>
/// <param name="actionIndices">The indices of the masked actions on branch 0</param>
public void SetActionMask(IEnumerable<int> actionIndices)
{
SetActionMask(0, actionIndices);
}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision for the specified
/// action branch. The actionIndex correspond to the action the agent will be unable
/// to perform.
/// </summary>
/// <param name="branch">The branch for which the actions will be masked</param>
/// <param name="actionIndex">The index of the masked action</param>
public void SetActionMask(int branch, int actionIndex)
{
SetActionMask(branch, new[] { actionIndex });
}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision. The actionIndex
/// correspond to the action the agent will be unable to perform on the branch 0.
/// </summary>
/// <param name="actionIndex">The index of the masked action on branch 0</param>
public void SetActionMask(int actionIndex)
{
SetActionMask(0, new[] { actionIndex });
}
/// <summary>
/// able to perform the action passed as argument at the next decision. If no branch is
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
/// to the action the agent will be unable to perform.
/// able to perform the actions passed as argument at the next decision for the specified
/// action branch. The actionIndices correspond to the action options the agent will
/// be unable to perform.
/// </summary>
/// <param name="branch">The branch for which the actions will be masked</param>
/// <param name="actionIndices">The indices of the masked actions</param>

/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
public bool[] GetMask()
internal bool[] GetMask()
{
if (m_CurrentMask != null)
{

/// <summary>
/// Resets the current mask for an agent
/// </summary>
public void ResetMask()
internal void ResetMask()
{
if (m_CurrentMask != null)
{

89
com.unity.ml-agents/Runtime/Agent.cs


UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
{
CollectObservations(collectObservationsSensor);
CollectObservations(collectObservationsSensor, m_ActionMasker);
}
m_Info.actionMasks = m_ActionMasker.GetMask();

/// - <see cref="AddObservation(float)"/>
/// - <see cref="AddObservation(Vector3)"/>
/// - <see cref="AddObservation(Vector2)"/>
/// - <see>
/// <cref>AddVectorObs(float[])</cref>
/// </see>
/// - <see>
/// <cref>AddVectorObs(List{float})</cref>
/// </see>
/// - <see cref="AddObservation(Quaternion)"/>
/// - <see cref="AddObservation(bool)"/>
/// - <see cref="AddOneHotObservation(int, int)"/>

}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision. If no branch is
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
/// to the action the agent will be unable to perform.
/// Collects the vector observations of the agent.
/// The agent observation describes the current environment from the
/// perspective of the agent.
/// <param name="actionIndices">The indices of the masked actions on branch 0</param>
protected void SetActionMask(IEnumerable<int> actionIndices)
{
m_ActionMasker.SetActionMask(0, actionIndices);
}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision. If no branch is
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
/// to the action the agent will be unable to perform.
/// </summary>
/// <param name="actionIndex">The index of the masked action on branch 0</param>
protected void SetActionMask(int actionIndex)
{
m_ActionMasker.SetActionMask(0, new[] { actionIndex });
}
/// <summary>
/// Sets an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision. If no branch is
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
/// to the action the agent will be unable to perform.
/// </summary>
/// <param name="branch">The branch for which the actions will be masked</param>
/// <param name="actionIndex">The index of the masked action</param>
protected void SetActionMask(int branch, int actionIndex)
{
m_ActionMasker.SetActionMask(branch, new[] { actionIndex });
}
/// <summary>
/// Modifies an action mask for discrete control agents. When used, the agent will not be
/// able to perform the action passed as argument at the next decision. If no branch is
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
/// to the action the agent will be unable to perform.
/// </summary>
/// <param name="branch">The branch for which the actions will be masked</param>
/// <param name="actionIndices">The indices of the masked actions</param>
protected void SetActionMask(int branch, IEnumerable<int> actionIndices)
/// <remarks>
/// An agents observation is any environment information that helps
/// the Agent achieve its goal. For example, for a fighting Agent, its
/// observation could include distances to friends or enemies, or the
/// current level of ammunition at its disposal.
/// Recall that an Agent may attach vector or visual observations.
/// Vector observations are added by calling the provided helper methods
/// on the VectorSensor input:
/// - <see cref="AddObservation(int)"/>
/// - <see cref="AddObservation(float)"/>
/// - <see cref="AddObservation(Vector3)"/>
/// - <see cref="AddObservation(Vector2)"/>
/// - <see cref="AddObservation(Quaternion)"/>
/// - <see cref="AddObservation(bool)"/>
/// - <see cref="AddOneHotObservation(int, int)"/>
/// Depending on your environment, any combination of these helpers can
/// be used. They just need to be used in the exact same order each time
/// this method is called and the resulting size of the vector observation
/// needs to match the vectorObservationSize attribute of the linked Brain.
/// Visual observations are implicitly added from the cameras attached to
/// the Agent.
/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it. You can call the following method on the ActionMasker
/// input :
/// - <see cref="SetActionMask(int branch, IEnumerable<int> actionIndices)"/>
/// - <see cref="SetActionMask(int branch, int actionIndex)"/>
/// - <see cref="SetActionMask(IEnumerable<int> actionIndices)"/>
/// - <see cref="SetActionMask(int branch, int actionIndex)"/>
/// The branch input is the index of the action, actionIndices are the indices of the
/// invalid options for that action.
/// </remarks>
public virtual void CollectObservations(VectorSensor sensor, ActionMasker actionMasker)
m_ActionMasker.SetActionMask(branch, actionIndices);
CollectObservations(sensor);
}
/// <summary>

6
docs/Learning-Environment-Design-Agents.md


neural network, the Agent will be unable to perform the specified action. Note
that when the Agent is controlled by its Heuristic, the Agent will
still be able to decide to perform the masked action. In order to mask an
action, call the method `SetActionMask` within the `CollectObservation` method :
action, call the method `SetActionMask` on the optional `ActionMasker` argument of the `CollectObservation` method :
SetActionMask(branch, actionIndices)
public override void CollectObservations(VectorSensor sensor, ActionMasker actionMasker){
actionMasker.SetActionMask(branch, actionIndices)
}
```
Where:

3
docs/Migrating.md


* The `Agent.CollectObservations()` virtual method now takes as input a `VectorSensor` sensor as argument. The `Agent.AddVectorObs()` methods were removed.
* The `Monitor` class has been moved to the Examples Project. (It was prone to errors during testing)
* The `MLAgents.Sensor` namespace has been removed. All sensors now belong to the `MLAgents` namespace.
* The `SetActionMask` method must now be called on the optional `ActionMasker` argument of the `CollectObservations` method. (We now consider an action mask as a type of observation)
* Replace your calls to `SetActionMask` on your Agent to `ActionMasker.SetActionMask` in `CollectObservations`
## Migrating from 0.13 to 0.14

正在加载...
取消
保存