浏览代码
Merge remote-tracking branch 'upstream/develop' into develop-flat-code-restructure
/develop-generalizationTraining-TrainerController
Merge remote-tracking branch 'upstream/develop' into develop-flat-code-restructure
/develop-generalizationTraining-TrainerController
Deric Pang
6 年前
当前提交
cdb41480
共有 43 个文件被更改,包括 581 次插入 和 56 次删除
-
64MLAgentsSDK/Assets/ML-Agents/Scripts/Agent.cs
-
9MLAgentsSDK/Assets/ML-Agents/Scripts/Batcher.cs
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentActionProto.cs.meta
-
29MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentInfoProto.cs
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/AgentInfoProto.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainParametersProto.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/BrainTypeProto.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/CommandProto.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/EngineConfigurationProto.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/EnvironmentParametersProto.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/Header.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/ResolutionProto.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/SpaceTypeProto.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityInput.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityMessage.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityOutput.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationInput.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInitializationOutput.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlInput.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityRlOutput.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternal.cs.meta
-
2MLAgentsSDK/Assets/ML-Agents/Scripts/CommunicatorObjects/UnityToExternalGrpc.cs.meta
-
44MLAgentsSDK/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
-
24docs/Learning-Environment-Design-Agents.md
-
5docs/localized/zh-CN/README.md
-
3python/mlagents/mlagents/envs/brain.py
-
11python/mlagents/mlagents/envs/communicator_objects/agent_info_proto_pb2.py
-
10python/mlagents/mlagents/envs/environment.py
-
4python/mlagents/mlagents/trainers/bc/models.py
-
6python/mlagents/mlagents/trainers/bc/trainer.py
-
21python/mlagents/mlagents/trainers/models.py
-
10python/mlagents/mlagents/trainers/ppo/trainer.py
-
6python/mlagents/tests/trainers/test_bc.py
-
15python/mlagents/tests/trainers/test_ppo.py
-
8MLAgentsSDK/Assets/ML-Agents/Editor/Tests.meta
-
154MLAgentsSDK/Assets/ML-Agents/Scripts/ActionMasker.cs
-
3MLAgentsSDK/Assets/ML-Agents/Scripts/ActionMasker.cs.meta
-
139MLAgentsSDK/Assets/ML-Agents/Editor/Tests/EditModeTestActionMasker.cs
-
11MLAgentsSDK/Assets/ML-Agents/Editor/Tests/EditModeTestActionMasker.cs.meta
-
11MLAgentsSDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs.meta
-
12MLAgentsSDK/Assets/ML-Agents/Editor/MLAgentsEditModeTest.cs.meta
-
0/MLAgentsSDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
|
|||
fileFormatVersion: 2 |
|||
guid: 172fcc71d343247a9a91d5b54dd21cd6 |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using System.Linq; |
|||
|
|||
namespace MLAgents |
|||
{ |
|||
/// <summary>
/// Keeps track of the discrete actions an agent is not allowed to take at its next
/// decision. The mask is a single boolean array covering every branch of the discrete
/// action space laid end to end; an entry set to true means the action is masked.
/// </summary>
public class ActionMasker
{
    /// When using discrete control, holds the starting index of each branch when all the
    /// branches are concatenated with each other. It has one extra trailing entry equal to
    /// the total number of actions. Built lazily by the first call to SetActionMask.
    private int[] _startingActionIndices;

    /// Flattened mask over all branches (true = masked). Remains null until
    /// SetActionMask is called for the first time.
    private bool[] _currentMask;

    private readonly BrainParameters _brainParameters;

    /// <summary>
    /// Creates a masker bound to the given brain parameters. No mask is allocated yet.
    /// </summary>
    /// <param name="brainParameters">Source of the action space type and of the size of
    /// each action branch</param>
    public ActionMasker(BrainParameters brainParameters)
    {
        this._brainParameters = brainParameters;
    }

    /// <summary>
    /// Modifies an action mask for discrete control agents. When used, the agent will not be
    /// able to perform the actions passed as argument at the next decision. The
    /// actionIndices are relative to the start of the given branch. Masks accumulate across
    /// calls until ResetMask is called.
    /// </summary>
    /// <remarks>
    /// Indices are validated one at a time while they are applied, so if an invalid index is
    /// encountered the indices that preceded it in the sequence stay masked.
    /// The action space type is not checked here; it is checked by GetMask.
    /// </remarks>
    /// <param name="branch">The branch for which the actions will be masked</param>
    /// <param name="actionIndices">The indices of the masked actions</param>
    /// <exception cref="UnityAgentsException">The branch does not exist, or an action index
    /// is negative or not smaller than the size of the branch</exception>
    /// <exception cref="ArgumentNullException">actionIndices is null</exception>
    public void SetActionMask(int branch, IEnumerable<int> actionIndices)
    {
        // If the branch does not exist, raise an error. Negative branches are rejected here
        // as well so that they do not surface later as an IndexOutOfRangeException.
        if (branch < 0 || branch >= _brainParameters.vectorActionSize.Length)
        {
            throw new UnityAgentsException(
                "Invalid Action Masking : Branch " + branch + " does not exist.");
        }

        // Fail before any state is allocated instead of crashing inside the loop below.
        if (actionIndices == null)
        {
            throw new ArgumentNullException("actionIndices");
        }

        // By default, the mask is null. The first time a mask is requested we allocate it;
        // every entry starts out false, i.e. no action is masked.
        if (_currentMask == null)
        {
            _currentMask = new bool[_brainParameters.vectorActionSize.Sum()];
        }

        // If this is the first time the masked actions are used, we generate the starting
        // indices for each branch.
        if (_startingActionIndices == null)
        {
            _startingActionIndices = CreateActionStartingIndices();
        }

        var branchSize = _brainParameters.vectorActionSize[branch];
        var branchStart = _startingActionIndices[branch];

        // Perform the masking
        foreach (var actionIndex in actionIndices)
        {
            // A negative index would otherwise land inside the previous branch (or before
            // the start of the array for branch 0) and corrupt the mask.
            if (actionIndex < 0)
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking: Action index " + actionIndex +
                    " is negative.");
            }

            if (actionIndex >= branchSize)
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking: Action Mask is too large for specified branch.");
            }

            _currentMask[branchStart + actionIndex] = true;
        }
    }

    /// <summary>
    /// Get the current mask for an agent. The mask is validated first, and the validation
    /// runs even when no mask has been set yet.
    /// </summary>
    /// <returns>A mask for the agent. A boolean array of length equal to the total number of
    /// actions, or null if SetActionMask has never been called.</returns>
    /// <exception cref="UnityAgentsException">The action space is not discrete, or every
    /// action of some branch is masked</exception>
    public bool[] GetMask()
    {
        AssertMask();
        return _currentMask;
    }

    /// <summary>
    /// Makes sure that the current mask is usable: the action space must be discrete and
    /// every branch must keep at least one unmasked action.
    /// </summary>
    private void AssertMask()
    {
        // Action Masks can only be used in Discrete Control.
        if (_brainParameters.vectorActionSpaceType != SpaceType.discrete)
        {
            throw new UnityAgentsException(
                "Invalid Action Masking : Can only set action mask for Discrete Control.");
        }

        var numBranches = _brainParameters.vectorActionSize.Length;
        for (var branchIndex = 0; branchIndex < numBranches; branchIndex++)
        {
            if (AreAllActionsMasked(branchIndex))
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking : All the actions of branch " + branchIndex +
                    " are masked.");
            }
        }
    }

    /// <summary>
    /// Resets the current mask for an agent. The existing array is cleared in place
    /// (every entry becomes false); nothing happens if no mask was ever allocated.
    /// </summary>
    public void ResetMask()
    {
        if (_currentMask != null)
        {
            Array.Clear(_currentMask, 0, _currentMask.Length);
        }
    }

    /// <summary>
    /// Generates an array containing the starting indices of each branch in the vector
    /// action. Makes a cumulative sum of the branch sizes.
    /// </summary>
    /// <returns>An array with one more entry than there are branches: entry i is the
    /// flattened index at which branch i starts, and the last entry is the total number of
    /// actions.</returns>
    private int[] CreateActionStartingIndices()
    {
        var vectorActionSize = _brainParameters.vectorActionSize;
        var runningSum = 0;
        // result[0] is left at its default of 0: the first branch starts at index 0.
        var result = new int[vectorActionSize.Length + 1];
        for (var actionIndex = 0;
            actionIndex < vectorActionSize.Length; actionIndex++)
        {
            runningSum += vectorActionSize[actionIndex];
            result[actionIndex + 1] = runningSum;
        }
        return result;
    }

    /// <summary>
    /// Checks if all the actions in the input branch are masked
    /// </summary>
    /// <param name="branch"> The index of the branch to check</param>
    /// <returns> True if all the actions of the branch are masked; false if at least one is
    /// available or if no mask has been allocated yet</returns>
    private bool AreAllActionsMasked(int branch)
    {
        if (_currentMask == null)
        {
            return false;
        }

        // The mask and the starting indices are both allocated by SetActionMask, so the
        // starting indices are available whenever the mask is.
        var start = _startingActionIndices[branch];
        var end = _startingActionIndices[branch + 1];
        for (var i = start; i < end; i++)
        {
            if (!_currentMask[i])
            {
                return false;
            }
        }
        return true;
    }
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 8a0ec4ccf4ee450da7766f65228d5460 |
|||
timeCreated: 1534530911 |
|
|||
using NUnit.Framework; |
|||
|
|||
namespace MLAgents.Tests |
|||
{ |
|||
/// <summary>
/// Edit-mode tests for ActionMasker. Most tests use a discrete action space with three
/// branches of sizes 4, 5 and 6, i.e. 15 flattened actions where branch 0 covers indices
/// 0-3, branch 1 covers 4-8 and branch 2 covers 9-14.
/// </summary>
public class EditModeTestActionMasker
{
    /// A masker can be built from default brain parameters.
    [Test]
    public void Contruction()
    {
        var bp = new BrainParameters();
        var masker = new ActionMasker(bp);
        Assert.IsNotNull(masker);
    }

    /// Setting a mask on a continuous action space is only rejected when the mask is read.
    [Test]
    public void FailsWithContinuous()
    {
        var bp = new BrainParameters();
        bp.vectorActionSpaceType = SpaceType.continuous;
        bp.vectorActionSize = new int[1] {4};
        var masker = new ActionMasker(bp);
        masker.SetActionMask(0, new int[1] {0});
        Assert.Catch<UnityAgentsException>(() => masker.GetMask());
    }

    /// Without any call to SetActionMask the mask is null.
    [Test]
    public void NullMask()
    {
        var bp = new BrainParameters();
        bp.vectorActionSpaceType = SpaceType.discrete;
        var masker = new ActionMasker(bp);
        var mask = masker.GetMask();
        Assert.IsNull(mask);
    }

    /// Masking actions 1-3 of branch 0 sets flattened indices 1-3 and nothing else nearby.
    [Test]
    public void FirstBranchMask()
    {
        var bp = new BrainParameters();
        bp.vectorActionSpaceType = SpaceType.discrete;
        bp.vectorActionSize = new int[3] {4, 5, 6};
        var masker = new ActionMasker(bp);
        var mask = masker.GetMask();
        Assert.IsNull(mask);
        masker.SetActionMask(0, new int[] {1, 2, 3});
        mask = masker.GetMask();
        Assert.IsFalse(mask[0]);
        Assert.IsTrue(mask[1]);
        Assert.IsTrue(mask[2]);
        Assert.IsTrue(mask[3]);
        Assert.IsFalse(mask[4]);
        // NUnit expects (expected, actual); the mask spans all 4 + 5 + 6 actions.
        Assert.AreEqual(15, mask.Length);
    }

    /// Masking actions 1-3 of branch 1 is offset by the size of branch 0 (indices 5-7).
    [Test]
    public void SecondBranchMask()
    {
        var bp = new BrainParameters();
        bp.vectorActionSpaceType = SpaceType.discrete;
        bp.vectorActionSize = new int[3] {4, 5, 6};
        var masker = new ActionMasker(bp);
        Assert.IsNull(masker.GetMask());
        masker.SetActionMask(1, new int[] {1, 2, 3});
        var mask = masker.GetMask();
        Assert.IsFalse(mask[0]);
        Assert.IsFalse(mask[4]);
        Assert.IsTrue(mask[5]);
        Assert.IsTrue(mask[6]);
        Assert.IsTrue(mask[7]);
        Assert.IsFalse(mask[8]);
        Assert.IsFalse(mask[9]);
    }

    /// ResetMask clears every entry of a previously set mask.
    [Test]
    public void MaskReset()
    {
        var bp = new BrainParameters();
        bp.vectorActionSpaceType = SpaceType.discrete;
        bp.vectorActionSize = new int[3] {4, 5, 6};
        var masker = new ActionMasker(bp);
        Assert.IsNull(masker.GetMask());
        masker.SetActionMask(1, new int[3] {1, 2, 3});
        var mask = masker.GetMask();
        // Make sure there really is something to reset.
        Assert.IsTrue(mask[5]);
        masker.ResetMask();
        mask = masker.GetMask();
        Assert.AreEqual(15, mask.Length);
        for (var i = 0; i < 15; i++)
        {
            Assert.IsFalse(mask[i]);
        }
    }

    /// Out-of-range action indices, unknown branches and fully masked branches are errors.
    [Test]
    public void ThrowsError()
    {
        var bp = new BrainParameters();
        bp.vectorActionSpaceType = SpaceType.discrete;
        bp.vectorActionSize = new int[3] {4, 5, 6};
        var masker = new ActionMasker(bp);

        // Index 5 is past the end of branch 0 (size 4) and branch 1 (size 5)...
        Assert.Catch<UnityAgentsException>(
            () => masker.SetActionMask(0, new int[1] {5}));
        Assert.Catch<UnityAgentsException>(
            () => masker.SetActionMask(1, new int[1] {5}));
        // ...but is the last valid action of branch 2 (size 6).
        masker.SetActionMask(2, new int[1] {5});
        // There is no branch 3.
        Assert.Catch<UnityAgentsException>(
            () => masker.SetActionMask(3, new int[1] {1}));
        // The mask is still usable at this point, so reading it must not throw.
        masker.GetMask();
        masker.ResetMask();
        // Masking every action of branch 0 makes the mask unusable.
        masker.SetActionMask(0, new int[4] {0, 1, 2, 3});
        Assert.Catch<UnityAgentsException>(
            () => masker.GetMask());
    }

    /// Successive calls to SetActionMask accumulate instead of replacing each other.
    [Test]
    public void MultipleMaskEdit()
    {
        var bp = new BrainParameters();
        bp.vectorActionSpaceType = SpaceType.discrete;
        bp.vectorActionSize = new int[3] {4, 5, 6};
        var masker = new ActionMasker(bp);
        masker.SetActionMask(0, new int[2] {0, 1});
        masker.SetActionMask(0, new int[1] {3});
        masker.SetActionMask(2, new int[1] {1});
        var mask = masker.GetMask();
        for (var i = 0; i < 15; i++)
        {
            // Branch 0 actions 0, 1, 3 and branch 2 action 1 (flattened index 9 + 1).
            var shouldBeMasked = (i == 0) || (i == 1) || (i == 3) || (i == 10);
            if (shouldBeMasked)
            {
                Assert.IsTrue(mask[i]);
            }
            else
            {
                Assert.IsFalse(mask[i]);
            }
        }
    }
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 2e2810ee6c8c64fb39abdf04b5d17f50 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 3170fcbfa5f4d4a8ca82c50c750e9083 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 64f5b117b5f304a4281f16eb904311fd |
|||
timeCreated: 1518706577 |
|||
licenseType: Free |
|||
MonoImporter: |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
撰写
预览
正在加载...
取消
保存
Reference in new issue