浏览代码
Develop action masking (#1080)
Develop action masking (#1080)
* [Initial Commit] Modified the model.py file and the ppo/trainer.py file to use masked actions * Preliminary modifications to the python side of the code to enable action masking * Preliminary modifications to the C# side of the code to enable action masking * Preliminary modifications to the communication side of the code to enable action masking * Implemented action masking for BC Note : The actions of the teacher are not masked * More error messages for the action masking * fix pytests * Added Documentation * Address comment * Addressed Comments on docs * Addressed second comment on docs * Addressed comments for the python side of the code * Created the action masker and associated unit tests * Addressed comments on the C# side * Addressed the comment regarding action_masking_name * Addressed the comments/develop-generalizationTraining-TrainerController
GitHub
6 年前
当前提交
ded0d8c7
共有 22 个文件被更改,包括 540 次插入 和 22 次删除
-
24docs/Learning-Environment-Design-Agents.md
-
11python/communicator_objects/agent_info_proto_pb2.py
-
6python/tests/test_bc.py
-
15python/tests/test_ppo.py
-
3python/unityagents/brain.py
-
9python/unityagents/environment.py
-
4python/unitytrainers/bc/models.py
-
6python/unitytrainers/bc/trainer.py
-
21python/unitytrainers/models.py
-
10python/unitytrainers/ppo/trainer.py
-
64unity-environment/Assets/ML-Agents/Scripts/Agent.cs
-
1unity-environment/Assets/ML-Agents/Scripts/Batcher.cs
-
29unity-environment/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentInfoProto.cs
-
44unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
-
8unity-environment/Assets/ML-Agents/Editor/Tests.meta
-
154unity-environment/Assets/ML-Agents/Scripts/ActionMasker.cs
-
3unity-environment/Assets/ML-Agents/Scripts/ActionMasker.cs.meta
-
139unity-environment/Assets/ML-Agents/Editor/Tests/EditModeTestActionMasker.cs
-
11unity-environment/Assets/ML-Agents/Editor/Tests/EditModeTestActionMasker.cs.meta
-
0/unity-environment/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
-
0/unity-environment/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs.meta
|
|||
fileFormatVersion: 2 |
|||
guid: 172fcc71d343247a9a91d5b54dd21cd6 |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using System.Linq; |
|||
|
|||
namespace MLAgents |
|||
{ |
|||
public class ActionMasker |
|||
{ |
|||
/// When using discrete control, is the starting indices of the actions
|
|||
/// when all the branches are concatenated with each other.
|
|||
private int[] _startingActionIndices; |
|||
|
|||
private bool[] _currentMask; |
|||
|
|||
private readonly BrainParameters _brainParameters; |
|||
|
|||
public ActionMasker(BrainParameters brainParameters) |
|||
{ |
|||
this._brainParameters = brainParameters; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Modifies an action mask for discrete control agents. When used, the agent will not be
|
|||
/// able to perform the action passed as argument at the next decision. If no branch is
|
|||
/// specified, the default branch will be 0. The actionIndex or actionIndices correspond
|
|||
/// to the action the agent will be unable to perform.
|
|||
/// </summary>
|
|||
/// <param name="branch">The branch for which the actions will be masked</param>
|
|||
/// <param name="actionIndices">The indices of the masked actions</param>
|
|||
public void SetActionMask(int branch, IEnumerable<int> actionIndices) |
|||
{ |
|||
// If the branch does not exist, raise an error
|
|||
if (branch >= _brainParameters.vectorActionSize.Length ) |
|||
throw new UnityAgentsException( |
|||
"Invalid Action Masking : Branch "+branch+" does not exist."); |
|||
|
|||
int totalNumberActions = _brainParameters.vectorActionSize.Sum(); |
|||
|
|||
// By default, the masks are null. If we want to specify a new mask, we initialize
|
|||
// the actionMasks with trues.
|
|||
if (_currentMask == null) |
|||
{ |
|||
_currentMask = new bool[totalNumberActions]; |
|||
} |
|||
|
|||
// If this is the first time the masked actions are used, we generate the starting
|
|||
// indices for each branch.
|
|||
if (_startingActionIndices == null) |
|||
{ |
|||
_startingActionIndices = CreateActionStartinIndices(); |
|||
} |
|||
|
|||
// Perform the masking
|
|||
foreach (var actionIndex in actionIndices) |
|||
{ |
|||
if (actionIndex >= _brainParameters.vectorActionSize[branch]) |
|||
{ |
|||
throw new UnityAgentsException( |
|||
"Invalid Action Masking: Action Mask is too large for specified branch."); |
|||
} |
|||
_currentMask[actionIndex + _startingActionIndices[branch]] = true; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Get the current mask for an agent
|
|||
/// </summary>
|
|||
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
|
|||
/// actions.</returns>
|
|||
public bool[] GetMask() |
|||
{ |
|||
AssertMask(); |
|||
return _currentMask; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Makes sure that the current mask is usable.
|
|||
/// </summary>
|
|||
private void AssertMask() |
|||
{ |
|||
// Action Masks can only be used in Discrete Control.
|
|||
if (_brainParameters.vectorActionSpaceType != SpaceType.discrete) |
|||
{ |
|||
throw new UnityAgentsException( |
|||
"Invalid Action Masking : Can only set action mask for Discrete Control."); |
|||
} |
|||
|
|||
var numBranches = _brainParameters.vectorActionSize.Length; |
|||
for (var branchIndex = 0 ; branchIndex < numBranches; branchIndex++ ) |
|||
{ |
|||
if (AreAllActionsMasked(branchIndex)) |
|||
{ |
|||
throw new UnityAgentsException( |
|||
"Invalid Action Masking : All the actions of branch " + branchIndex + |
|||
" are masked."); |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Resets the current mask for an agent
|
|||
/// </summary>
|
|||
public void ResetMask() |
|||
{ |
|||
if (_currentMask != null) |
|||
{ |
|||
Array.Clear(_currentMask, 0, _currentMask.Length); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Generates an array containing the starting indicies of each branch in the vector action
|
|||
/// Makes a cumulative sum.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
private int[] CreateActionStartinIndices() |
|||
{ |
|||
var vectorActionSize = _brainParameters.vectorActionSize; |
|||
var runningSum = 0; |
|||
var result = new int[vectorActionSize.Length + 1]; |
|||
for (var actionIndex = 0; |
|||
actionIndex < vectorActionSize.Length; actionIndex++) |
|||
{ |
|||
runningSum += vectorActionSize[actionIndex]; |
|||
result[actionIndex + 1] = runningSum; |
|||
} |
|||
return result; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Checks if all the actions in the input branch are masked
|
|||
/// </summary>
|
|||
/// <param name="branch"> The index of the branch to check</param>
|
|||
/// <returns> True if all the actions of the branch are masked</returns>
|
|||
private bool AreAllActionsMasked(int branch) |
|||
{ |
|||
if (_currentMask == null) |
|||
{ |
|||
return false; |
|||
} |
|||
var start = _startingActionIndices[branch]; |
|||
var end = _startingActionIndices[branch + 1]; |
|||
for (var i = start; i < end; i++) |
|||
{ |
|||
if (!_currentMask[i]) |
|||
{ |
|||
return false; |
|||
} |
|||
} |
|||
return true; |
|||
|
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 8a0ec4ccf4ee450da7766f65228d5460 |
|||
timeCreated: 1534530911 |
|
|||
using NUnit.Framework; |
|||
|
|||
namespace MLAgents.Tests |
|||
{ |
|||
public class EditModeTestActionMasker |
|||
{ |
|||
[Test] |
|||
public void Contruction() |
|||
{ |
|||
var bp = new BrainParameters(); |
|||
var masker = new ActionMasker(bp); |
|||
Assert.IsNotNull(masker); |
|||
} |
|||
|
|||
[Test] |
|||
public void FailsWithContinuous() |
|||
{ |
|||
var bp = new BrainParameters(); |
|||
bp.vectorActionSpaceType = SpaceType.continuous; |
|||
bp.vectorActionSize = new int[1] {4}; |
|||
var masker = new ActionMasker(bp); |
|||
masker.SetActionMask(0, new int[1] {0}); |
|||
Assert.Catch<UnityAgentsException>(() => masker.GetMask()); |
|||
|
|||
} |
|||
|
|||
[Test] |
|||
public void NullMask() |
|||
{ |
|||
var bp = new BrainParameters(); |
|||
bp.vectorActionSpaceType = SpaceType.discrete; |
|||
var masker = new ActionMasker(bp); |
|||
var mask = masker.GetMask(); |
|||
Assert.IsNull(mask); |
|||
} |
|||
|
|||
[Test] |
|||
public void FirstBranchMask() |
|||
{ |
|||
var bp = new BrainParameters(); |
|||
bp.vectorActionSpaceType = SpaceType.discrete; |
|||
bp.vectorActionSize = new int[3] {4, 5, 6}; |
|||
var masker = new ActionMasker(bp); |
|||
var mask = masker.GetMask(); |
|||
Assert.IsNull(mask); |
|||
masker.SetActionMask(0, new int[]{1,2,3}); |
|||
mask = masker.GetMask(); |
|||
Assert.IsFalse(mask[0]); |
|||
Assert.IsTrue(mask[1]); |
|||
Assert.IsTrue(mask[2]); |
|||
Assert.IsTrue(mask[3]); |
|||
Assert.IsFalse(mask[4]); |
|||
Assert.AreEqual(mask.Length, 15); |
|||
} |
|||
|
|||
[Test] |
|||
public void SecondBranchMask() |
|||
{ |
|||
var bp = new BrainParameters(); |
|||
bp.vectorActionSpaceType = SpaceType.discrete; |
|||
bp.vectorActionSize = new int[3] {4, 5, 6}; |
|||
var masker = new ActionMasker(bp); |
|||
bool[] mask = masker.GetMask(); |
|||
masker.SetActionMask(1, new int[]{1,2,3}); |
|||
mask = masker.GetMask(); |
|||
Assert.IsFalse(mask[0]); |
|||
Assert.IsFalse(mask[4]); |
|||
Assert.IsTrue(mask[5]); |
|||
Assert.IsTrue(mask[6]); |
|||
Assert.IsTrue(mask[7]); |
|||
Assert.IsFalse(mask[8]); |
|||
Assert.IsFalse(mask[9]); |
|||
} |
|||
|
|||
[Test] |
|||
public void MaskReset() |
|||
{ |
|||
var bp = new BrainParameters(); |
|||
bp.vectorActionSpaceType = SpaceType.discrete; |
|||
bp.vectorActionSize = new int[3] {4, 5, 6}; |
|||
var masker = new ActionMasker(bp); |
|||
var mask = masker.GetMask(); |
|||
masker.SetActionMask(1, new int[3]{1,2,3}); |
|||
mask = masker.GetMask(); |
|||
masker.ResetMask(); |
|||
mask = masker.GetMask(); |
|||
for (var i = 0; i < 15; i++) |
|||
{ |
|||
Assert.IsFalse(mask[i]); |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void ThrowsError() |
|||
{ |
|||
var bp = new BrainParameters(); |
|||
bp.vectorActionSpaceType = SpaceType.discrete; |
|||
bp.vectorActionSize = new int[3] {4, 5, 6}; |
|||
var masker = new ActionMasker(bp); |
|||
|
|||
Assert.Catch<UnityAgentsException>( |
|||
() => masker.SetActionMask(0, new int[1]{5})); |
|||
Assert.Catch<UnityAgentsException>( |
|||
() => masker.SetActionMask(1, new int[1]{5})); |
|||
masker.SetActionMask(2, new int[1] {5}); |
|||
Assert.Catch<UnityAgentsException>( |
|||
() => masker.SetActionMask(3, new int[1]{1})); |
|||
masker.GetMask(); |
|||
masker.ResetMask(); |
|||
masker.SetActionMask(0, new int[4] {0, 1, 2, 3}); |
|||
Assert.Catch<UnityAgentsException>( |
|||
() => masker.GetMask()); |
|||
} |
|||
|
|||
[Test] |
|||
public void MultipleMaskEdit() |
|||
{ |
|||
var bp = new BrainParameters(); |
|||
bp.vectorActionSpaceType = SpaceType.discrete; |
|||
bp.vectorActionSize = new int[3] {4, 5, 6}; |
|||
var masker = new ActionMasker(bp); |
|||
masker.SetActionMask(0, new int[2] {0, 1}); |
|||
masker.SetActionMask(0, new int[1] {3}); |
|||
masker.SetActionMask(2, new int[1] {1}); |
|||
var mask = masker.GetMask(); |
|||
for (var i = 0; i < 15; i++) |
|||
{ |
|||
if ((i == 0) || (i == 1) || (i == 3)|| (i == 10)) |
|||
{ |
|||
Assert.IsTrue(mask[i]); |
|||
} |
|||
else |
|||
{ |
|||
Assert.IsFalse(mask[i]); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 2e2810ee6c8c64fb39abdf04b5d17f50 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
撰写
预览
正在加载...
取消
保存
Reference in new issue