Andrew Cohen
4 年前
当前提交
598826fe
共有 120 个文件被更改,包括 2958 次插入 和 1925 次删除
-
7.pre-commit-config.yaml
-
2.yamato/com.unity.ml-agents-performance.yml
-
2.yamato/com.unity.ml-agents-test.yml
-
14.yamato/gym-interface-test.yml
-
5.yamato/protobuf-generation-test.yml
-
14.yamato/python-ll-api-test.yml
-
2.yamato/standalone-build-test.yml
-
2.yamato/training-int-tests.yml
-
35Project/Assets/ML-Agents/Examples/Crawler/Prefabs/Crawler.prefab
-
5Project/Assets/ML-Agents/Examples/Crawler/Prefabs/FixedPlatform.prefab
-
12Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
-
1001Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn
-
1001Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn
-
37Project/Assets/ML-Agents/Examples/PushBlock/Prefabs/PushBlockVisualArea.prefab
-
2Project/ProjectSettings/ProjectVersion.txt
-
10README.md
-
2com.unity.ml-agents.extensions/README.md
-
6com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
-
14com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
-
31com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
-
211com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
-
77com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
-
9com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
-
11com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
-
88com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
-
58com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
-
25com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
-
2com.unity.ml-agents.extensions/package.json
-
25com.unity.ml-agents/CHANGELOG.md
-
2com.unity.ml-agents/Documentation~/com.unity.ml-agents.md
-
6com.unity.ml-agents/Runtime/Academy.cs
-
26com.unity.ml-agents/Runtime/Agent.cs
-
1com.unity.ml-agents/Runtime/AssemblyInfo.cs
-
2com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs
-
2com.unity.ml-agents/Runtime/DiscreteActionMasker.cs
-
2com.unity.ml-agents/package.json
-
4docs/Installation-Anaconda-Windows.md
-
6docs/Installation.md
-
2docs/Training-Configuration-File.md
-
81docs/Training-ML-Agents.md
-
2docs/Training-on-Amazon-Web-Service.md
-
22docs/Unity-Inference-Engine.md
-
8docs/Using-Tensorboard.md
-
2gym-unity/gym_unity/__init__.py
-
2ml-agents-envs/mlagents_envs/__init__.py
-
2ml-agents-envs/setup.py
-
18ml-agents/mlagents/model_serialization.py
-
2ml-agents/mlagents/trainers/__init__.py
-
2ml-agents/mlagents/trainers/buffer.py
-
2ml-agents/mlagents/trainers/environment_parameter_manager.py
-
8ml-agents/mlagents/trainers/exception.py
-
14ml-agents/mlagents/trainers/learn.py
-
26ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
-
3ml-agents/mlagents/trainers/ppo/optimizer_torch.py
-
25ml-agents/mlagents/trainers/ppo/trainer.py
-
5ml-agents/mlagents/trainers/sac/optimizer_torch.py
-
49ml-agents/mlagents/trainers/sac/trainer.py
-
12ml-agents/mlagents/trainers/settings.py
-
59ml-agents/mlagents/trainers/stats.py
-
64ml-agents/mlagents/trainers/tests/test_env_param_manager.py
-
42ml-agents/mlagents/trainers/tests/test_stats.py
-
4ml-agents/mlagents/trainers/tests/torch/test_distributions.py
-
9ml-agents/mlagents/trainers/tests/torch/test_networks.py
-
12ml-agents/mlagents/trainers/tests/torch/test_utils.py
-
3ml-agents/mlagents/trainers/torch/decoders.py
-
37ml-agents/mlagents/trainers/torch/distributions.py
-
170ml-agents/mlagents/trainers/torch/encoders.py
-
52ml-agents/mlagents/trainers/torch/utils.py
-
24ml-agents/mlagents/trainers/trainer/rl_trainer.py
-
3ml-agents/tests/yamato/yamato_utils.py
-
1utils/make_readme_table.py
-
3com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs
-
11com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs.meta
-
8com.unity.ml-agents/Runtime/Actuators.meta
-
8com.unity.ml-agents/Tests/Editor/Actuators.meta
-
160docs/images/TensorBoard-download.png
-
20ml-agents/mlagents/trainers/tests/torch/test_layers.py
-
48ml-agents/mlagents/trainers/torch/layers.py
-
181com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
-
3com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta
-
75com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
-
3com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta
-
17com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs
-
3com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta
-
150com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs
-
3com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta
-
415com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
-
3com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta
-
101com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
-
3com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta
-
21com.unity.ml-agents/Runtime/Actuators/IActuator.cs
-
3com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta
-
38com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
-
3com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta
-
72com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
-
3com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta
1001
Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerDynamic.nn
文件差异内容过多而无法显示
查看文件
文件差异内容过多而无法显示
查看文件
1001
Project/Assets/ML-Agents/Examples/Crawler/TFModels/CrawlerStatic.nn
文件差异内容过多而无法显示
查看文件
文件差异内容过多而无法显示
查看文件
|
|||
m_EditorVersion: 2018.4.20f1 |
|||
m_EditorVersion: 2018.4.24f1 |
|
|||
# ML-Agents Extensions |
|||
|
|||
This is a source-only package for new features based on ML-Agents. |
|||
|
|||
More details coming soon. |
|
|||
# Version of the library that will be used to upload to pypi |
|||
__version__ = "0.19.0.dev0" |
|||
__version__ = "0.20.0.dev0" |
|||
|
|||
# Git tag that will be checked to determine whether to trigger upload to pypi |
|||
__release_tag__ = None |
|
|||
# Version of the library that will be used to upload to pypi |
|||
__version__ = "0.19.0.dev0" |
|||
__version__ = "0.20.0.dev0" |
|||
|
|||
# Git tag that will be checked to determine whether to trigger upload to pypi |
|||
__release_tag__ = None |
|
|||
# Version of the library that will be used to upload to pypi |
|||
__version__ = "0.19.0.dev0" |
|||
__version__ = "0.20.0.dev0" |
|||
|
|||
# Git tag that will be checked to determine whether to trigger upload to pypi |
|||
__release_tag__ = None |
|
|||
using System.Runtime.CompilerServices; |
|||
|
|||
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")] |
|
|||
fileFormatVersion: 2 |
|||
guid: 48c8790647c3345e19c57d6c21065112 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 26733e59183b6479e8f0e892a8bf09a4 |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: c7e705f7d549e43c6be18ae809cd6f54 |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
import torch |
|||
|
|||
from mlagents.trainers.torch.layers import Swish, linear_layer, Initialization |
|||
|
|||
|
|||
def test_swish():
    """Swish must return its input multiplied element-wise by sigmoid(input)."""
    swish = Swish()
    inputs = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    expected = torch.mul(inputs, torch.sigmoid(inputs))
    matches = torch.eq(swish(inputs), expected)
    assert torch.all(matches)
|||
|
|||
|
|||
def test_initialization_layer():
    """A linear_layer built with Zero initializers has all-zero weights and biases."""
    torch.manual_seed(0)
    # Test Zero
    zero_layer = linear_layer(
        3, 4, kernel_init=Initialization.Zero, bias_init=Initialization.Zero
    )
    weights = zero_layer.weight.data
    biases = zero_layer.bias.data
    assert torch.all(torch.eq(weights, torch.zeros_like(weights)))
    assert torch.all(torch.eq(biases, torch.zeros_like(biases)))
|
|||
import torch |
|||
from enum import Enum |
|||
|
|||
|
|||
class Swish(torch.nn.Module):
    """Swish activation: the input multiplied element-wise by its own sigmoid."""

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(data)
        return data * gate
|||
|
|||
|
|||
class Initialization(Enum):
    """Initialization schemes that linear_layer can apply to weights and biases."""

    Zero = 0
    XavierGlorotNormal = 1
    XavierGlorotUniform = 2
    KaimingHeNormal = 3  # also known as Variance scaling
    KaimingHeUniform = 4


# Maps each Initialization member to the in-place torch initializer implementing it.
_init_methods = {
    Initialization.Zero: torch.zero_,
    Initialization.XavierGlorotNormal: torch.nn.init.xavier_normal_,
    Initialization.XavierGlorotUniform: torch.nn.init.xavier_uniform_,
    Initialization.KaimingHeNormal: torch.nn.init.kaiming_normal_,
    Initialization.KaimingHeUniform: torch.nn.init.kaiming_uniform_,
}


def linear_layer(
    input_size: int,
    output_size: int,
    kernel_init: Initialization = Initialization.XavierGlorotUniform,
    kernel_gain: float = 1.0,
    bias_init: Initialization = Initialization.Zero,
) -> torch.nn.Module:
    """
    Builds a torch.nn.Linear module and re-initializes its parameters.

    The weights are initialized with ``kernel_init`` and then scaled in place by
    ``kernel_gain``; the bias is initialized with ``bias_init``.

    :param input_size: The size of the input tensor
    :param output_size: The size of the output tensor
    :param kernel_init: The Initialization to use for the weights of the layer
    :param kernel_gain: The multiplier for the weights of the kernel. Note that in
    TensorFlow, calling variance_scaling with scale 0.01 is equivalent to calling
    KaimingHeNormal with kernel_gain of 0.1
    :param bias_init: The Initialization to use for the weights of the bias layer
    :return: The initialized torch.nn.Linear module
    """
    module = torch.nn.Linear(input_size, output_size)
    init_kernel = _init_methods[kernel_init]
    init_bias = _init_methods[bias_init]
    init_kernel(module.weight.data)
    module.weight.data.mul_(kernel_gain)
    init_bias(module.bias.data)
    return module
|
|||
using System; |
|||
using System.Collections; |
|||
using System.Collections.Generic; |
|||
using System.Diagnostics; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
/// ActionSegment{T} is a data structure that allows access to a segment of an underlying array
/// in order to avoid the copying and allocation of sub-arrays. The segment is defined by
/// the offset into the original array, and a length.
/// </summary>
/// <typeparam name="T">The type of object stored in the underlying <see cref="Array"/></typeparam>
internal readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
    where T : struct
{
    /// <summary>
    /// The zero-based offset into the original array at which this segment starts.
    /// </summary>
    public readonly int Offset;

    /// <summary>
    /// The number of items this segment can access in the underlying array.
    /// </summary>
    public readonly int Length;

    /// <summary>
    /// An Empty segment which has an offset of 0, a Length of 0, and its underlying array
    /// is also empty.
    /// </summary>
    public static ActionSegment<T> Empty = new ActionSegment<T>(System.Array.Empty<T>(), 0, 0);

    /// <summary>
    /// Validates the constructor arguments. The check is only compiled in when the
    /// `DEBUG` preprocessor macro is defined.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">Thrown when offset or length is negative, or
    /// when offset + length runs past the end of actionArray.</exception>
    static void CheckParameters(T[] actionArray, int offset, int length)
    {
#if DEBUG
        // Negative values would otherwise slip through the upper-bound comparison.
        if (offset < 0 || length < 0 || offset + length > actionArray.Length)
        {
            throw new ArgumentOutOfRangeException(nameof(offset),
                $"Arguments offset: {offset} and length: {length} " +
                $"are out of bounds of actionArray: {actionArray.Length}.");
        }
#endif
    }

    /// <summary>
    /// Construct an <see cref="ActionSegment{T}"/> with an underlying array
    /// and offset, and a length.
    /// </summary>
    /// <param name="actionArray">The underlying array which this segment has a view into</param>
    /// <param name="offset">The zero-based offset into the underlying array.</param>
    /// <param name="length">The length of the segment.</param>
    public ActionSegment(T[] actionArray, int offset, int length)
    {
        CheckParameters(actionArray, offset, length);
        Array = actionArray;
        Offset = offset;
        Length = length;
    }

    /// <summary>
    /// Get the underlying <see cref="Array"/> of this segment.
    /// </summary>
    public T[] Array { get; }

    /// <summary>
    /// Allows access to the underlying array using array syntax.
    /// </summary>
    /// <param name="index">The zero-based index of the segment.</param>
    /// <exception cref="IndexOutOfRangeException">Thrown when the index is less than 0 or
    /// greater than or equal to <see cref="Length"/></exception>
    public T this[int index]
    {
        get
        {
            // Valid indices are 0 .. Length - 1; index == Length would read the element that
            // follows this segment in the underlying array.
            if (index < 0 || index >= Length)
            {
                throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length}");
            }
            return Array[Offset + index];
        }
    }

    /// <inheritdoc cref="IEnumerable{T}.GetEnumerator"/>
    IEnumerator<T> IEnumerable<T>.GetEnumerator()
    {
        return new Enumerator(this);
    }

    /// <inheritdoc cref="IEnumerable{T}"/>
    public IEnumerator GetEnumerator()
    {
        return new Enumerator(this);
    }

    /// <inheritdoc cref="ValueType.Equals(object)"/>
    public override bool Equals(object obj)
    {
        return obj is ActionSegment<T> other && Equals(other);
    }

    /// <inheritdoc cref="IEquatable{T}.Equals(T)"/>
    public bool Equals(ActionSegment<T> other)
    {
        // Two segments are equal when they view the same range of the same array instance.
        return Offset == other.Offset && Length == other.Length && ReferenceEquals(Array, other.Array);
    }

    /// <inheritdoc cref="ValueType.GetHashCode"/>
    public override int GetHashCode()
    {
        unchecked
        {
            var hashCode = Offset;
            hashCode = (hashCode * 397) ^ Length;
            hashCode = (hashCode * 397) ^ (Array != null ? Array.GetHashCode() : 0);
            return hashCode;
        }
    }

    /// <summary>
    /// A private <see cref="IEnumerator{T}"/> for the <see cref="ActionSegment{T}"/> value type which follows its
    /// rules of being a view into an underlying <see cref="Array"/>.
    /// </summary>
    struct Enumerator : IEnumerator<T>
    {
        readonly T[] m_Array;
        readonly int m_Start;
        readonly int m_End; // cache Offset + Length so it is not recomputed on every step
        int m_Current;

        internal Enumerator(ActionSegment<T> arraySegment)
        {
            Debug.Assert(arraySegment.Array != null);
            Debug.Assert(arraySegment.Offset >= 0);
            Debug.Assert(arraySegment.Length >= 0);
            Debug.Assert(arraySegment.Offset + arraySegment.Length <= arraySegment.Array.Length);

            m_Array = arraySegment.Array;
            m_Start = arraySegment.Offset;
            m_End = arraySegment.Offset + arraySegment.Length;
            // Positioned one before the first element until MoveNext is called.
            m_Current = arraySegment.Offset - 1;
        }

        public bool MoveNext()
        {
            if (m_Current < m_End)
            {
                m_Current++;
                return m_Current < m_End;
            }
            return false;
        }

        public T Current
        {
            get
            {
                if (m_Current < m_Start)
                {
                    throw new InvalidOperationException("Enumerator not started.");
                }
                if (m_Current >= m_End)
                {
                    throw new InvalidOperationException("Enumerator has reached the end already.");
                }
                return m_Array[m_Current];
            }
        }

        object IEnumerator.Current => Current;

        void IEnumerator.Reset()
        {
            m_Current = m_Start - 1;
        }

        public void Dispose()
        {
        }
    }
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 4fa1432c1ba3460caaa84303a9011ef2 |
|||
timeCreated: 1595869823 |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using System.Linq; |
|||
using Unity.MLAgents.Policies; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
/// Describes the shape of an action space for use by the Actuator system.
/// </summary>
internal readonly struct ActionSpec
{
    /// <summary>
    /// The branch sizes of the action space.
    ///
    /// For an IActuator that uses a Discrete <see cref="SpaceType"/>, the Length of this array is the
    /// number of branches and each element holds the size of that branch. The total across all
    /// branches is exposed through <see cref="SumOfDiscreteBranchSizes"/>.
    ///
    /// For an IActuator with a Continuous space it will be null.
    /// </summary>
    public readonly int[] BranchSizes;

    /// <summary>
    /// The number of actions for a Continuous <see cref="SpaceType"/>.
    /// </summary>
    public int NumContinuousActions { get; }

    /// <summary>
    /// The number of branches for a Discrete <see cref="SpaceType"/>.
    /// </summary>
    public int NumDiscreteActions { get; }

    /// <summary>
    /// The total number of Discrete Actions that can be taken, i.e. the sum
    /// of all of the Discrete Action branch sizes.
    /// </summary>
    public int SumOfDiscreteBranchSizes { get; }

    /// <summary>
    /// Creates a Continuous <see cref="ActionSpec"/> with the number of actions available.
    /// </summary>
    /// <param name="numActions">The number of actions available.</param>
    /// <returns>A Continuous ActionSpec initialized with the number of actions available.</returns>
    public static ActionSpec MakeContinuous(int numActions) => new ActionSpec(numActions, 0);

    /// <summary>
    /// Creates a Discrete <see cref="ActionSpec"/> with the array of branch sizes that
    /// represents the action space.
    /// </summary>
    /// <param name="branchSizes">The array of branch sizes for the discrete action space. Each index
    /// contains the number of actions available for that branch.</param>
    /// <returns>A Discrete ActionSpec initialized with the array of branch sizes.</returns>
    public static ActionSpec MakeDiscrete(int[] branchSizes) =>
        new ActionSpec(0, branchSizes.Length, branchSizes);

    ActionSpec(int numContinuousActions, int numDiscreteActions, int[] branchSizes = null)
    {
        BranchSizes = branchSizes;
        NumContinuousActions = numContinuousActions;
        NumDiscreteActions = numDiscreteActions;
        // A spec without branches (the continuous case) has nothing to sum.
        SumOfDiscreteBranchSizes = branchSizes == null ? 0 : branchSizes.Sum();
    }
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: ecdd6deefba1416ca149fe09d2a5afd8 |
|||
timeCreated: 1595892361 |
|
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
|
|||
/// Abstract MonoBehaviour used in the Editor to create Actuators. Generally an IActuator
|
|||
/// implementation should have a corresponding ActuatorComponent.
|
|||
/// </summary>
|
|||
internal abstract class ActuatorComponent : MonoBehaviour |
|||
{ |
|||
/// <summary>
|
|||
/// Creates the IActuator for this component. This is called by the Agent when it is initialized.
|
|||
/// </summary>
|
|||
/// <returns>The created IActuator object.</returns>
|
|||
public abstract IActuator CreateActuator(); |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 77cefae5f6d841be9ff80b41293d271b |
|||
timeCreated: 1593017318 |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using System.Linq; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
/// Implementation of IDiscreteActionMask that allows writing to the action mask from an <see cref="IActuator"/>.
/// </summary>
internal class ActuatorDiscreteActionMask : IDiscreteActionMask
{
    /// When using discrete control, is the starting indices of the actions
    /// when all the branches are concatenated with each other.
    int[] m_StartingActionIndices;

    // Branch sizes of every actuator, concatenated in actuator order. Built lazily.
    int[] m_BranchSizes;

    // One flag per discrete action across all branches; true means the action is masked.
    bool[] m_CurrentMask;

    IList<IActuator> m_Actuators;

    readonly int m_SumOfDiscreteBranchSizes;
    readonly int m_NumBranches;

    /// <summary>
    /// The offset into the branches array that is used when actuators are writing to the action mask.
    /// </summary>
    public int CurrentBranchOffset { get; set; }

    internal ActuatorDiscreteActionMask(IList<IActuator> actuators, int sumOfDiscreteBranchSizes, int numBranches)
    {
        m_Actuators = actuators;
        m_SumOfDiscreteBranchSizes = sumOfDiscreteBranchSizes;
        m_NumBranches = numBranches;
    }

    /// <inheritdoc cref="IDiscreteActionMask.WriteMask"/>
    public void WriteMask(int branch, IEnumerable<int> actionIndices)
    {
        LazyInitialize();

        // Perform the masking
        foreach (var actionIndex in actionIndices)
        {
#if DEBUG
            // The branch is relative to CurrentBranchOffset, so validate the absolute index
            // before it is used to look up the branch size.
            var branchIndex = CurrentBranchOffset + branch;
            if (branch < 0 || branchIndex >= m_NumBranches ||
                actionIndex < 0 || actionIndex >= m_BranchSizes[branchIndex])
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking: Action Mask is too large for specified branch.");
            }
#endif
            m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = true;
        }
    }

    /// <summary>
    /// Allocates the branch-size table, the mask and the starting indices the first time
    /// they are needed.
    /// </summary>
    void LazyInitialize()
    {
        if (m_BranchSizes == null)
        {
            m_BranchSizes = new int[m_NumBranches];
            var start = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var branchSizes = m_Actuators[i].ActionSpec.BranchSizes;
                if (branchSizes == null)
                {
                    // Continuous ActionSpecs carry no branch sizes; Array.Copy would throw on null.
                    continue;
                }
                Array.Copy(branchSizes, 0, m_BranchSizes, start, branchSizes.Length);
                start += branchSizes.Length;
            }
        }

        // By default, the mask is null. The first time a mask is written we allocate it;
        // every entry starts as false (not masked).
        if (m_CurrentMask == null)
        {
            m_CurrentMask = new bool[m_SumOfDiscreteBranchSizes];
        }

        // If this is the first time the masked actions are used, we generate the starting
        // indices for each branch.
        // NOTE(review): AreAllActionsMasked reads index branch + 1, so Utilities.CumSum is
        // presumably returning one more element than m_BranchSizes has - confirm.
        if (m_StartingActionIndices == null)
        {
            m_StartingActionIndices = Utilities.CumSum(m_BranchSizes);
        }
    }

    /// <inheritdoc cref="IDiscreteActionMask.GetMask"/>
    public bool[] GetMask()
    {
#if DEBUG
        if (m_CurrentMask != null)
        {
            AssertMask();
        }
#endif
        return m_CurrentMask;
    }

    /// <summary>
    /// Makes sure that the current mask is usable.
    /// </summary>
    /// <exception cref="UnityAgentsException">Thrown (in DEBUG builds) when every action of
    /// some branch is masked.</exception>
    void AssertMask()
    {
#if DEBUG
        for (var branchIndex = 0; branchIndex < m_NumBranches; branchIndex++)
        {
            if (AreAllActionsMasked(branchIndex))
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking : All the actions of branch " + branchIndex +
                    " are masked.");
            }
        }
#endif
    }

    /// <summary>
    /// Resets the current mask for an agent.
    /// </summary>
    public void ResetMask()
    {
        if (m_CurrentMask != null)
        {
            Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
        }
    }

    /// <summary>
    /// Checks if all the actions in the input branch are masked.
    /// </summary>
    /// <param name="branch"> The index of the branch to check.</param>
    /// <returns> True if all the actions of the branch are masked.</returns>
    bool AreAllActionsMasked(int branch)
    {
        if (m_CurrentMask == null)
        {
            return false;
        }
        var start = m_StartingActionIndices[branch];
        var end = m_StartingActionIndices[branch + 1];
        for (var i = start; i < end; i++)
        {
            if (!m_CurrentMask[i])
            {
                return false;
            }
        }
        return true;
    }
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: d2a19e2f43fd4637a38d42b2a5f989f3 |
|||
timeCreated: 1595459316 |
|
|||
using System; |
|||
using System.Collections; |
|||
using System.Collections.Generic; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
|
|||
/// A class that manages the delegation of events, action buffers, and action mask for a list of IActuators.
|
|||
/// </summary>
|
|||
internal class ActuatorManager : IList<IActuator> |
|||
{ |
|||
// IActuators managed by this object.
|
|||
IList<IActuator> m_Actuators; |
|||
|
|||
// An implementation of IDiscreteActionMask that allows for writing to it based on an offset.
|
|||
ActuatorDiscreteActionMask m_DiscreteActionMask; |
|||
|
|||
/// <summary>
|
|||
/// Flag used to check if our IActuators are ready for execution.
|
|||
/// </summary>
|
|||
/// <seealso cref="ReadyActuatorsForExecution(IList{IActuator}, int, int, int)"/>
|
|||
bool m_ReadyForExecution; |
|||
|
|||
/// <summary>
|
|||
/// The sum of all of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
|
|||
/// </summary>
|
|||
internal int SumOfDiscreteBranchSizes { get; private set; } |
|||
|
|||
/// <summary>
|
|||
/// The number of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
|
|||
/// </summary>
|
|||
internal int NumDiscreteActions { get; private set; } |
|||
|
|||
/// <summary>
|
|||
/// The number of continuous actions for all of the <see cref="IActuator"/>s in this manager.
|
|||
/// </summary>
|
|||
internal int NumContinuousActions { get; private set; } |
|||
|
|||
/// <summary>
|
|||
/// Returns the total actions which is calculated by <see cref="NumContinuousActions"/> + <see cref="NumDiscreteActions"/>.
|
|||
/// </summary>
|
|||
public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions; |
|||
|
|||
/// <summary>
|
|||
/// Gets the <see cref="IDiscreteActionMask"/> managed by this object.
|
|||
/// </summary>
|
|||
public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask; |
|||
|
|||
/// <summary>
|
|||
/// Returns the previously stored continuous actions for the actuators in this list.
|
|||
/// </summary>
|
|||
public float[] StoredContinuousActions { get; private set; } |
|||
|
|||
/// <summary>
|
|||
/// Returns the previously stored discrete actions for the actuators in this list.
|
|||
/// </summary>
|
|||
public int[] StoredDiscreteActions { get; private set; } |
|||
|
|||
/// <summary>
|
|||
/// Create an ActuatorManager with a preset capacity.
|
|||
/// </summary>
|
|||
/// <param name="capacity">The capacity of the list to create.</param>
|
|||
public ActuatorManager(int capacity = 0) |
|||
{ |
|||
m_Actuators = new List<IActuator>(capacity); |
|||
} |
|||
|
|||
/// <summary>
/// Readies the actuators using this manager's own actuator list and action counts.
/// <see cref="ReadyActuatorsForExecution(IList{IActuator}, int, int, int)"/>
/// </summary>
void ReadyActuatorsForExecution() =>
    ReadyActuatorsForExecution(
        m_Actuators,
        NumContinuousActions,
        SumOfDiscreteBranchSizes,
        NumDiscreteActions);
|||
|
|||
/// <summary>
/// This method sorts the <see cref="IActuator"/>s by name, validates that they have unique names and
/// equivalent action space types if the `DEBUG` preprocessor macro is defined, and allocates the
/// appropriate buffers to manage the actions for all of the <see cref="IActuator"/>s that may live on a
/// particular object. It does nothing once the manager has been readied.
/// </summary>
/// <param name="actuators">The list of actuators to allocate buffers for; it is handed to the
/// <see cref="ActuatorDiscreteActionMask"/> created here.</param>
/// <param name="numContinuousActions">The total number of continuous actions for all of the actuators.</param>
/// <param name="sumOfDiscreteBranches">The total sum of the discrete branches for all of the actuators in order
/// to be able to allocate an <see cref="IDiscreteActionMask"/>.</param>
/// <param name="numDiscreteBranches">The number of discrete branches for all of the actuators.</param>
internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numContinuousActions, int sumOfDiscreteBranches, int numDiscreteBranches)
{
    if (m_ReadyForExecution)
    {
        return;
    }

    // Sort the Actuators by name to ensure determinism. This has to happen before validation:
    // ValidateActuators only compares neighbouring entries, so duplicate names are only
    // guaranteed to be detected once equal names sit next to each other.
    SortActuators();
#if DEBUG
    // Make sure the names are actually unique
    // Make sure all Actuators have the same SpaceType
    ValidateActuators();
#endif

    StoredContinuousActions = numContinuousActions == 0 ? Array.Empty<float>() : new float[numContinuousActions];
    StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty<int>() : new int[numDiscreteBranches];
    m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches);
    m_ReadyForExecution = true;
}
|||
|
|||
/// <summary>
|
|||
/// Updates the stored continuous and discrete action buffers from the buffers passed in. If a buffer
|
|||
/// passed in is null or empty, the corresponding stored action buffer will be cleared.
|
|||
/// </summary>
|
|||
/// <param name="continuousActionBuffer">The action buffer which contains all of the
|
|||
/// continuous actions for the IActuators in this list.</param>
|
|||
/// <param name="discreteActionBuffer">The action buffer which contains all of the
|
|||
/// discrete actions for the IActuators in this list.</param>
|
|||
public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer) |
|||
{ |
|||
ReadyActuatorsForExecution(); |
|||
UpdateActionArray(continuousActionBuffer, StoredContinuousActions); |
|||
UpdateActionArray(discreteActionBuffer, StoredDiscreteActions); |
|||
} |
|||
|
|||
/// <summary>
/// Overwrites <paramref name="destination"/> with <paramref name="sourceActionBuffer"/>, or
/// clears it when the source is null or empty.
/// </summary>
/// <param name="sourceActionBuffer">The incoming actions; may be null or empty.</param>
/// <param name="destination">The stored action array to overwrite.</param>
static void UpdateActionArray<T>(T[] sourceActionBuffer, T[] destination)
{
    var hasActions = sourceActionBuffer != null && sourceActionBuffer.Length != 0;
    if (!hasActions)
    {
        // Nothing was provided: reset every stored action to its default value.
        Array.Clear(destination, 0, destination.Length);
        return;
    }

    Debug.Assert(sourceActionBuffer.Length == destination.Length,
        $"sourceActionBuffer:{sourceActionBuffer.Length} is a different" +
        $" size than destination: {destination.Length}.");

    Array.Copy(sourceActionBuffer, destination, destination.Length);
}
|||
|
|||
/// <summary>
/// Resets the <see cref="IDiscreteActionMask"/> and then lets every actuator managed by this object
/// write to it, each one at the branch offset where its own discrete branches begin.
/// </summary>
public void WriteActionMask()
{
    ReadyActuatorsForExecution();
    m_DiscreteActionMask.ResetMask();

    var branchOffset = 0;
    for (var index = 0; index < m_Actuators.Count; ++index)
    {
        var current = m_Actuators[index];
        // Branch indices written by the actuator are relative to this offset.
        m_DiscreteActionMask.CurrentBranchOffset = branchOffset;
        current.WriteDiscreteActionMask(m_DiscreteActionMask);
        branchOffset += current.ActionSpec.NumDiscreteActions;
    }
}
|||
|
|||
/// <summary>
/// Walks the IActuators in this list and calls <see cref="IActionReceiver.OnActionReceived"/> on each
/// one, handing it the <see cref="ActionSegment{T}"/>s of the stored action buffers that correspond
/// to its <see cref="IActionReceiver.ActionSpec"/>.
/// </summary>
public void ExecuteActions()
{
    ReadyActuatorsForExecution();

    var continuousOffset = 0;
    var discreteOffset = 0;
    for (var index = 0; index < m_Actuators.Count; ++index)
    {
        var current = m_Actuators[index];
        var continuousCount = current.ActionSpec.NumContinuousActions;
        var discreteCount = current.ActionSpec.NumDiscreteActions;

        // An actuator without actions of a given kind receives the shared Empty segment.
        var continuousActions = continuousCount > 0
            ? new ActionSegment<float>(StoredContinuousActions, continuousOffset, continuousCount)
            : ActionSegment<float>.Empty;
        var discreteActions = discreteCount > 0
            ? new ActionSegment<int>(StoredDiscreteActions, discreteOffset, discreteCount)
            : ActionSegment<int>.Empty;

        current.OnActionReceived(new ActionBuffers(continuousActions, discreteActions));

        continuousOffset += continuousCount;
        discreteOffset += discreteCount;
    }
}
|||
|
|||
/// <summary>
/// Zeroes the <see cref="StoredContinuousActions"/> and <see cref="StoredDiscreteActions"/> buffers and
/// then calls <see cref="IActuator.ResetData"/> on every <see cref="IActuator"/> managed by this object.
/// </summary>
/// <remarks>
/// Does nothing while the manager has not been readied for execution.
/// </remarks>
public void ResetData()
{
    if (m_ReadyForExecution)
    {
        Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length);
        Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length);

        for (var index = 0; index < m_Actuators.Count; index++)
        {
            m_Actuators[index].ResetData();
        }
    }
}
|||
|
|||
|
|||
/// <summary>
/// Sorts the <see cref="IActuator"/>s according to their <see cref="IActuator.Name"/> value.
/// </summary>
/// <remarks>
/// Names are compared with <see cref="System.StringComparison.InvariantCulture"/> rather than the
/// culture-sensitive <see cref="string.CompareTo(string)"/>, so the resulting order does not depend on
/// the culture of the thread that performs the sort. A null name no longer throws; it sorts before
/// any non-null name.
/// </remarks>
void SortActuators()
{
    // Sort() is a List<T> member, hence the cast on m_Actuators.
    ((List<IActuator>)m_Actuators).Sort(
        (x, y) => string.Compare(x.Name, y.Name, System.StringComparison.InvariantCulture));
}
|||
|
|||
/// <summary>
/// Checks each pair of neighbouring actuators in the list: their names must differ, and either both
/// or neither may declare continuous actions.
/// </summary>
/// <remarks>
/// Only adjacent entries are compared, so duplicate names are only caught when equal names sit next
/// to each other (as they do once the list has been sorted by name). Unique names are what allow the
/// storage of action buffers and the execution order of actuators to stay deterministic between runs.
/// </remarks>
void ValidateActuators()
{
    var lastPairStart = m_Actuators.Count - 1;
    for (var index = 0; index < lastPairStart; index++)
    {
        Debug.Assert(
            !m_Actuators[index].Name.Equals(m_Actuators[index + 1].Name),
            "Actuator names must be unique.");

        var currentSpec = m_Actuators[index].ActionSpec;
        var nextSpec = m_Actuators[index + 1].ActionSpec;
        Debug.Assert(
            (currentSpec.NumContinuousActions > 0) == (nextSpec.NumContinuousActions > 0),
            "Actuators on the same Agent must have the same action SpaceType.");
    }
}
|||
|
|||
/// <summary>
/// Adds the action counts declared by an actuator to this object's running totals
/// (continuous actions, discrete actions and the sum of discrete branch sizes).
/// </summary>
/// <param name="actuatorItem">The IActuator being added; a null value leaves the totals untouched.</param>
void AddToBufferSizes(IActuator actuatorItem)
{
    if (actuatorItem != null)
    {
        NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions;
        NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions;
        SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
    }
}
|||
|
|||
/// <summary>
/// Removes the action counts declared by an actuator from this object's running totals
/// (continuous actions, discrete actions and the sum of discrete branch sizes).
/// </summary>
/// <param name="actuatorItem">The IActuator being removed; a null value leaves the totals untouched.</param>
void SubtractFromBufferSize(IActuator actuatorItem)
{
    if (actuatorItem != null)
    {
        NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions;
        NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions;
        SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
    }
}
|||
|
|||
/// <summary>
/// Resets the three running totals (continuous actions, discrete actions and the sum of discrete
/// branch sizes) to zero.
/// </summary>
void ClearBufferSizes()
{
    // Assigned in the same right-to-left order as a chained assignment would use.
    SumOfDiscreteBranchSizes = 0;
    NumDiscreteActions = 0;
    NumContinuousActions = 0;
}
|||
|
|||
/********************************************************************************* |
|||
* IList implementation that delegates to m_Actuators List. * |
|||
*********************************************************************************/ |
|||
|
|||
/// <inheritdoc cref="IEnumerable{T}.GetEnumerator"/>
/// <remarks>
/// Delegates directly to the underlying actuator list.
/// </remarks>
public IEnumerator<IActuator> GetEnumerator() => m_Actuators.GetEnumerator();
|||
|
|||
/// <inheritdoc cref="IEnumerable.GetEnumerator"/>
/// <remarks>
/// Non-generic counterpart of <see cref="GetEnumerator"/>; delegates to the underlying actuator list.
/// </remarks>
IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)m_Actuators).GetEnumerator();
|||
|
|||
/// <inheritdoc cref="ICollection{T}.Add"/>
/// <remarks>
/// Asserts that the manager has not yet been readied for execution, appends the actuator to the
/// list and adds its action counts to the running totals.
/// </remarks>
/// <param name="item">The actuator to append.</param>
public void Add(IActuator item)
{
    Debug.Assert(!m_ReadyForExecution,
        "Cannot add to the ActuatorManager after its buffers have been initialized");

    m_Actuators.Add(item);
    AddToBufferSizes(item);
}
|||
|
|||
/// <inheritdoc cref="ICollection{T}.Clear"/>
/// <remarks>
/// Asserts that the manager has not yet been readied for execution, empties the actuator list and
/// sets the running action totals back to zero.
/// </remarks>
public void Clear()
{
    Debug.Assert(!m_ReadyForExecution,
        "Cannot clear the ActuatorManager after its buffers have been initialized");

    m_Actuators.Clear();
    ClearBufferSizes();
}
|||
|
|||
/// <inheritdoc cref="ICollection{T}.Contains"/>
/// <remarks>
/// Delegates directly to the underlying actuator list.
/// </remarks>
public bool Contains(IActuator item) => m_Actuators.Contains(item);
|||
|
|||
/// <inheritdoc cref="ICollection{T}.CopyTo"/>
/// <remarks>
/// Delegates directly to the underlying actuator list.
/// </remarks>
public void CopyTo(IActuator[] array, int arrayIndex) => m_Actuators.CopyTo(array, arrayIndex);
|||
|
|||
/// <inheritdoc cref="ICollection{T}.Remove"/>
/// <remarks>
/// Asserts that the manager has not yet been readied for execution. The running action totals are
/// only reduced when the actuator was actually found in, and removed from, the list.
/// </remarks>
public bool Remove(IActuator item)
{
    Debug.Assert(!m_ReadyForExecution,
        "Cannot remove from the ActuatorManager after its buffers have been initialized");

    var wasRemoved = m_Actuators.Remove(item);
    if (wasRemoved)
    {
        SubtractFromBufferSize(item);
    }

    return wasRemoved;
}
|||
|
|||
/// <inheritdoc cref="ICollection{T}.Count"/>
/// <remarks>
/// Reports the number of entries in the underlying actuator list.
/// </remarks>
public int Count
{
    get { return m_Actuators.Count; }
}
|||
|
|||
/// <inheritdoc cref="ICollection{T}.IsReadOnly"/>
/// <remarks>
/// Reports whatever the underlying actuator list reports.
/// </remarks>
public bool IsReadOnly
{
    get { return m_Actuators.IsReadOnly; }
}
|||
|
|||
/// <inheritdoc cref="IList{T}.IndexOf"/>
/// <remarks>
/// Delegates directly to the underlying actuator list.
/// </remarks>
public int IndexOf(IActuator item) => m_Actuators.IndexOf(item);
|||
|
|||
/// <inheritdoc cref="IList{T}.Insert"/>
/// <remarks>
/// Asserts that the manager has not yet been readied for execution, inserts the actuator at the
/// requested position and adds its action counts to the running totals.
/// </remarks>
public void Insert(int index, IActuator item)
{
    Debug.Assert(!m_ReadyForExecution,
        "Cannot insert into the ActuatorManager after its buffers have been initialized");

    m_Actuators.Insert(index, item);
    AddToBufferSizes(item);
}
|||
|
|||
/// <inheritdoc cref="IList{T}.RemoveAt"/>
/// <remarks>
/// Asserts that the manager has not yet been readied for execution. The action counts of the
/// actuator stored at <paramref name="index"/> are subtracted from the running totals before the
/// entry is removed from the list.
/// </remarks>
public void RemoveAt(int index)
{
    Debug.Assert(!m_ReadyForExecution,
        "Cannot remove from the ActuatorManager after its buffers have been initialized");

    // Look the actuator up first: once removed, its ActionSpec is no longer reachable by index.
    SubtractFromBufferSize(m_Actuators[index]);
    m_Actuators.RemoveAt(index);
}
|||
|
|||
/// <inheritdoc cref="IList{T}.this"/>
/// <remarks>
/// The setter asserts that the manager has not yet been readied for execution, then swaps the entry
/// while keeping the running action totals in step: the replaced actuator's counts are subtracted and
/// the new actuator's counts are added.
/// </remarks>
public IActuator this[int index]
{
    get
    {
        return m_Actuators[index];
    }
    set
    {
        Debug.Assert(!m_ReadyForExecution,
            "Cannot modify the ActuatorManager after its buffers have been initialized");

        var replaced = m_Actuators[index];
        SubtractFromBufferSize(replaced);

        m_Actuators[index] = value;
        AddToBufferSizes(value);
    }
}
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 7bb5b1e3779d4342a8e70f6e3c1d67cc |
|||
timeCreated: 1593031463 |
|
|||
using System; |
|||
using System.Linq; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
/// A structure that wraps the <see cref="ActionSegment{T}"/>s for a particular <see cref="IActionReceiver"/> and is
/// used when <see cref="IActionReceiver.OnActionReceived"/> is called.
/// </summary>
internal readonly struct ActionBuffers
{
    /// <summary>
    /// An empty action buffer: both the continuous and the discrete segment are the empty segment.
    /// </summary>
    /// <remarks>
    /// Declared <c>readonly</c> so that this shared sentinel cannot be reassigned by other code.
    /// </remarks>
    public static readonly ActionBuffers Empty =
        new ActionBuffers(ActionSegment<float>.Empty, ActionSegment<int>.Empty);

    /// <summary>
    /// Holds the Continuous <see cref="ActionSegment{T}"/> to be used by an <see cref="IActionReceiver"/>.
    /// </summary>
    public ActionSegment<float> ContinuousActions { get; }

    /// <summary>
    /// Holds the Discrete <see cref="ActionSegment{T}"/> to be used by an <see cref="IActionReceiver"/>.
    /// </summary>
    public ActionSegment<int> DiscreteActions { get; }

    /// <summary>
    /// Construct an <see cref="ActionBuffers"/> instance with the continuous and discrete actions that will
    /// be used.
    /// </summary>
    /// <param name="continuousActions">The continuous actions to send to an <see cref="IActionReceiver"/>.</param>
    /// <param name="discreteActions">The discrete actions to send to an <see cref="IActionReceiver"/>.</param>
    public ActionBuffers(ActionSegment<float> continuousActions, ActionSegment<int> discreteActions)
    {
        ContinuousActions = continuousActions;
        DiscreteActions = discreteActions;
    }

    /// <inheritdoc cref="ValueType.Equals(object)"/>
    /// <remarks>
    /// Two instances are equal when their continuous segments and their discrete segments contain the
    /// same elements in the same order. Any object that is not an <see cref="ActionBuffers"/> is unequal.
    /// </remarks>
    public override bool Equals(object obj)
    {
        return obj is ActionBuffers other
            && other.ContinuousActions.SequenceEqual(ContinuousActions)
            && other.DiscreteActions.SequenceEqual(DiscreteActions);
    }

    /// <inheritdoc cref="ValueType.GetHashCode"/>
    /// <remarks>
    /// Combines the hash codes of the two segments.
    /// NOTE(review): <see cref="Equals(object)"/> compares element sequences; confirm that
    /// <see cref="ActionSegment{T}"/> hashes consistently with that comparison.
    /// </remarks>
    public override int GetHashCode()
    {
        // Overflow while mixing the two hashes is expected and harmless.
        unchecked
        {
            return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode();
        }
    }
}
|||
|
|||
/// <summary>
/// Describes an object that is able to receive actions produced by a Reinforcement Learning network.
/// </summary>
internal interface IActionReceiver
{
    /// <summary>
    /// The specification of the action space used by this IActionReceiver.
    /// </summary>
    /// <seealso cref="ActionSpec"/>
    ActionSpec ActionSpec { get; }

    /// <summary>
    /// Called to allow the object to execute actions based on the contents of the supplied
    /// <see cref="ActionBuffers"/>. How those contents are structured is defined by the
    /// <see cref="ActionSpec"/>.
    /// </summary>
    /// <param name="actionBuffers">The data structure containing the action buffers for this object.</param>
    void OnActionReceived(ActionBuffers actionBuffers);

    /// <summary>
    /// Implement `WriteDiscreteActionMask()` to modify the masks for discrete actions.
    /// When using discrete actions, the agent will not perform a masked action.
    /// </summary>
    /// <param name="actionMask">
    /// The action mask for the agent.
    /// </param>
    /// <remarks>
    /// When using Discrete Control, a certain action can be made unavailable to the Agent by
    /// masking it with <see cref="IDiscreteActionMask.WriteMask"/>.
    ///
    /// See [Agents - Actions] for more information on masking actions.
    ///
    /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#actions
    /// </remarks>
    /// <seealso cref="IActionReceiver.OnActionReceived"/>
    void WriteDiscreteActionMask(IDiscreteActionMask actionMask);
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: b25a5b3027c9476ea1a310241be0f10f |
|||
timeCreated: 1594756775 |
|
|||
using System; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
/// An <see cref="IActionReceiver"/> that additionally exposes a name, a total action count and a
/// way to reset its data; the abstraction through which actions are executed.
/// </summary>
internal interface IActuator : IActionReceiver
{
    /// <summary>
    /// The total number of actions handled by this IActuator.
    /// </summary>
    int TotalNumberOfActions { get; }

    /// <summary>
    /// The name of this IActuator, which is used to sort it.
    /// </summary>
    string Name { get; }

    /// <summary>
    /// Resets any action data held by this IActuator.
    /// </summary>
    void ResetData();
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 780d7f0a675f44bfa784b370025b51c3 |
|||
timeCreated: 1592848317 |
|
|||
using System.Collections.Generic; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
/// Interface used to write a mask that disables discrete actions for an agent at its next decision.
/// </summary>
internal interface IDiscreteActionMask
{
    /// <summary>
    /// Modifies an action mask for discrete control agents.
    /// </summary>
    /// <remarks>
    /// Once written, the agent cannot perform the actions passed as argument at the next decision
    /// for the given action branch. The entries of <paramref name="actionIndices"/> identify the
    /// action options that the agent will be unable to perform.
    ///
    /// See [Agents - Actions] for more information on masking actions.
    ///
    /// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#actions
    /// </remarks>
    /// <param name="branch">The branch for which the actions will be masked.</param>
    /// <param name="actionIndices">The indices of the masked actions.</param>
    void WriteMask(int branch, IEnumerable<int> actionIndices);

    /// <summary>
    /// Get the current mask for an agent.
    /// </summary>
    /// <returns>A mask for the agent. A boolean array of length equal to the total number of
    /// actions.</returns>
    bool[] GetMask();

    /// <summary>
    /// Resets the current mask for an agent.
    /// </summary>
    void ResetMask();
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 1bc4e4b71bf4470789488fab2ee65388 |
|||
timeCreated: 1595369065 |
|
|||
using System; |
|||
|
|||
using Unity.MLAgents.Policies; |
|||
|
|||
namespace Unity.MLAgents.Actuators |
|||
{ |
|||
/// <summary>
/// An <see cref="IActuator"/> that wraps a single <see cref="IActionReceiver"/>: received actions and
/// discrete mask requests are forwarded to that receiver, and the most recently received
/// <see cref="ActionBuffers"/> are remembered.
/// </summary>
internal class VectorActuator : IActuator
{
    IActionReceiver m_ActionReceiver;

    ActionBuffers m_ActionBuffers;

    /// <summary>
    /// <inheritdoc cref="IActionReceiver.ActionSpec"/>
    /// </summary>
    public ActionSpec ActionSpec { get; }

    /// <summary>
    /// The name passed to the constructor followed by "-Continuous" or "-Discrete", depending on the
    /// space type the actuator was created with.
    /// </summary>
    public string Name { get; }

    /// <summary>
    /// Returns the number of discrete branches + the number of continuous actions.
    /// </summary>
    public int TotalNumberOfActions
    {
        get { return ActionSpec.NumContinuousActions + ActionSpec.NumDiscreteActions; }
    }

    /// <summary>
    /// The buffers most recently passed to <see cref="OnActionReceived"/>, or
    /// the empty buffers after <see cref="ResetData"/> has been called.
    /// </summary>
    internal ActionBuffers ActionBuffers
    {
        get { return m_ActionBuffers; }
        private set { m_ActionBuffers = value; }
    }

    /// <summary>
    /// Creates an actuator that forwards to <paramref name="actionReceiver"/>.
    /// </summary>
    /// <param name="actionReceiver">The receiver that actions and mask requests are forwarded to.</param>
    /// <param name="vectorActionSize">
    /// For a continuous space only the first element is used; for a discrete space the whole array is
    /// passed on to build the <see cref="ActionSpec"/>.
    /// </param>
    /// <param name="spaceType">Whether the action space is continuous or discrete.</param>
    /// <param name="name">Base name of the actuator; a suffix naming the space type is appended.</param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="spaceType"/> is neither continuous nor discrete.
    /// </exception>
    public VectorActuator(IActionReceiver actionReceiver,
                          int[] vectorActionSize,
                          SpaceType spaceType,
                          string name = "VectorActuator")
    {
        m_ActionReceiver = actionReceiver;

        if (spaceType == SpaceType.Continuous)
        {
            ActionSpec = ActionSpec.MakeContinuous(vectorActionSize[0]);
            Name = name + "-Continuous";
        }
        else if (spaceType == SpaceType.Discrete)
        {
            ActionSpec = ActionSpec.MakeDiscrete(vectorActionSize);
            Name = name + "-Discrete";
        }
        else
        {
            throw new ArgumentOutOfRangeException(nameof(spaceType),
                spaceType,
                "Unknown enum value.");
        }
    }

    /// <summary>
    /// Replaces the remembered action buffers with the empty buffers.
    /// </summary>
    public void ResetData() => m_ActionBuffers = ActionBuffers.Empty;

    /// <summary>
    /// Remembers the received buffers and then forwards them to the wrapped receiver.
    /// </summary>
    /// <param name="actionBuffers">The action buffers for this actuator.</param>
    public void OnActionReceived(ActionBuffers actionBuffers)
    {
        m_ActionBuffers = actionBuffers;
        m_ActionReceiver.OnActionReceived(actionBuffers);
    }

    /// <summary>
    /// Forwards the mask to the wrapped receiver so that it can write its discrete action mask.
    /// </summary>
    /// <param name="actionMask">The action mask to write to.</param>
    public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) =>
        m_ActionReceiver.WriteDiscreteActionMask(actionMask);
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: ff7a3292c0b24b23b3f1c0eeb690ec4c |
|||
timeCreated: 1593023833 |
部分文件因为文件数量过多而无法显示
撰写
预览
正在加载...
取消
保存
Reference in new issue