浏览代码

Merge branch 'master' into hh/develop/loco-walker-variable-speed

/active-variablespeed
HH 4 年前
当前提交
d4bd7fe6
共有 48 个文件被更改,包括 1981 次插入116 次删除
  1. 2
      Project/ProjectSettings/ProjectVersion.txt
  2. 2
      com.unity.ml-agents.extensions/package.json
  3. 21
      com.unity.ml-agents/CHANGELOG.md
  4. 2
      com.unity.ml-agents/Runtime/Academy.cs
  5. 2
      com.unity.ml-agents/package.json
  6. 2
      docs/Training-ML-Agents.md
  7. 22
      docs/Unity-Inference-Engine.md
  8. 8
      docs/Using-Tensorboard.md
  9. 2
      gym-unity/gym_unity/__init__.py
  10. 2
      ml-agents-envs/mlagents_envs/__init__.py
  11. 18
      ml-agents/mlagents/model_serialization.py
  12. 2
      ml-agents/mlagents/trainers/__init__.py
  13. 59
      ml-agents/mlagents/trainers/stats.py
  14. 42
      ml-agents/mlagents/trainers/tests/test_stats.py
  15. 8
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  16. 8
      com.unity.ml-agents/Runtime/Actuators.meta
  17. 8
      com.unity.ml-agents/Tests/Editor/Actuators.meta
  18. 160
      docs/images/TensorBoard-download.png
  19. 181
      com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs
  20. 3
      com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta
  21. 75
      com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs
  22. 3
      com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta
  23. 17
      com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs
  24. 3
      com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta
  25. 150
      com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs
  26. 3
      com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta
  27. 415
      com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
  28. 3
      com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta
  29. 101
      com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs
  30. 3
      com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta
  31. 21
      com.unity.ml-agents/Runtime/Actuators/IActuator.cs
  32. 3
      com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta
  33. 38
      com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs
  34. 3
      com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta
  35. 72
      com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
  36. 3
      com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta
  37. 55
      com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs
  38. 3
      com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs.meta
  39. 114
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs
  40. 3
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs.meta
  41. 310
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
  42. 3
      com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs.meta
  43. 38
      com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs
  44. 3
      com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs.meta
  45. 98
      com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs
  46. 3
      com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs.meta

2
Project/ProjectSettings/ProjectVersion.txt


m_EditorVersion: 2018.4.18f1
m_EditorVersion: 2018.4.24f1

2
com.unity.ml-agents.extensions/package.json


"unity": "2018.4",
"description": "A source-only package for new features based on ML-Agents",
"dependencies": {
"com.unity.ml-agents": "1.2.0-preview"
"com.unity.ml-agents": "1.3.0-preview"
}
}

21
com.unity.ml-agents/CHANGELOG.md


### Major Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- The minimum supported python version for ml-agents-envs was changed to 3.6.1. (#4244)
### Minor Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
### Bug Fixes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
## [1.3.0-preview] 2020-08-12
### Major Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- The minimum supported Python version for ml-agents-envs was changed to 3.6.1. (#4244)
- The interaction between EnvManager and TrainerController was changed; EnvManager.advance() was split into to stages,
and TrainerController now uses the results from the first stage to handle new behavior names. This change speeds up
Python training by approximately 5-10%. (#4259)

#### ml-agents / ml-agents-envs / gym-unity (Python)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- The versions of `numpy` supported by ml-agents-envs were changed to disallow 1.19.0 or later. This was done to reflect
a similar change in TensorFlow's requirements. (#4274)
- CSV statistics writer was removed (#4300).
### Bug Fixes
#### com.unity.ml-agents (C#)

2
com.unity.ml-agents/Runtime/Academy.cs


/// Unity package version of com.unity.ml-agents.
/// This must match the version string in package.json and is checked in a unit test.
/// </summary>
internal const string k_PackageVersion = "1.2.0-preview";
internal const string k_PackageVersion = "1.3.0-preview";
const int k_EditorTrainingPort = 5004;

2
com.unity.ml-agents/package.json


{
"name": "com.unity.ml-agents",
"displayName": "ML Agents",
"version": "1.2.0-preview",
"version": "1.3.0-preview",
"unity": "2018.4",
"description": "Use state-of-the-art machine learning to create intelligent character behaviors in any Unity environment (games, robotics, film, etc.).",
"dependencies": {

2
docs/Training-ML-Agents.md


mlagents-learn config/ppo/3DBall_randomize.yaml --run-id=3D-Ball-randomize
```
We can observe progress and metrics via Tensorboard.
We can observe progress and metrics via TensorBoard.
#### Curriculum

22
docs/Unity-Inference-Engine.md


[compute shaders](https://docs.unity3d.com/Manual/class-ComputeShader.html) to
run the neural network within Unity.
**Note**: The ML-Agents Toolkit only supports the models created with our
trainers.
## Supported devices
See the Unity Inference Engine documentation for a list of the

**Note:** For most of the models generated with the ML-Agents Toolkit, CPU will
be faster than GPU. You should use the GPU only if you use the ResNet visual
encoder or have a large number of agents with visual observations.
# Unsupported use cases
## Externally trained models
The ML-Agents Toolkit only supports the models created with our trainers. Model
loading expects certain conventions for constants and tensor names. While it is
possible to construct a model that follows these conventions, we don't provide
any additional help for this. More details can be found in
[TensorNames.cs](https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/com.unity.ml-agents/Runtime/Inference/TensorNames.cs)
and
[BarracudaModelParamLoader.cs](https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs).
If you wish to run inference on an externally trained model, you should use
Barracuda directly, instead of trying to run it through ML-Agents.
## Model inference outside of Unity
We do not provide support for inference anywhere outside of Unity. The
`frozen_graph_def.pb` and `.onnx` files produced by training are open formats
for TensorFlow and ONNX respectively; if you wish to convert these to another
format or run inference with them, refer to their documentation.

8
docs/Using-Tensorboard.md


skill level between two players. In a proper training run, the ELO of the
agent should steadily increase.
## Exporting Data from TensorBoard
To export timeseries data in CSV or JSON format, check the "Show data download
links" in the upper left. This will enable download links below each chart.
![Example TensorBoard Run](images/TensorBoard-download.png)
To get custom metrics from a C# environment into Tensorboard, you can use the
To get custom metrics from a C# environment into TensorBoard, you can use the
`StatsRecorder`:
```csharp

2
gym-unity/gym_unity/__init__.py


# Version of the library that will be used to upload to pypi
__version__ = "0.19.0.dev0"
__version__ = "0.20.0.dev0"
# Git tag that will be checked to determine whether to trigger upload to pypi
__release_tag__ = None

2
ml-agents-envs/mlagents_envs/__init__.py


# Version of the library that will be used to upload to pypi
__version__ = "0.19.0.dev0"
__version__ = "0.20.0.dev0"
# Git tag that will be checked to determine whether to trigger upload to pypi
__release_tag__ = None

18
ml-agents/mlagents/model_serialization.py


from distutils.util import strtobool
import os
import shutil
from typing import Any, List, Set, NamedTuple
from distutils.version import LooseVersion

return strtobool(val)
except Exception:
return False
def copy_model_files(source_nn_path: str, destination_nn_path: str) -> None:
"""
Copy the .nn file at the given source to the destination.
Also copies the corresponding .onnx file if it exists.
"""
shutil.copyfile(source_nn_path, destination_nn_path)
logger.info(f"Copied {source_nn_path} to {destination_nn_path}.")
# Copy the onnx file if it exists
source_onnx_path = os.path.splitext(source_nn_path)[0] + ".onnx"
destination_onnx_path = os.path.splitext(destination_nn_path)[0] + ".onnx"
try:
shutil.copyfile(source_onnx_path, destination_onnx_path)
logger.info(f"Copied {source_onnx_path} to {destination_onnx_path}.")
except OSError:
pass

2
ml-agents/mlagents/trainers/__init__.py


# Version of the library that will be used to upload to pypi
__version__ = "0.19.0.dev0"
__version__ = "0.20.0.dev0"
# Git tag that will be checked to determine whether to trigger upload to pypi
__release_tag__ = None

59
ml-agents/mlagents/trainers/stats.py


from typing import List, Dict, NamedTuple, Any, Optional
import numpy as np
import abc
import csv
import os
import time
from threading import RLock

"""
Add a generic property to the StatsWriter. This could be e.g. a Dict of hyperparameters,
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
with all types of properties. For instance, a TB writer doesn't need a max step, nor should
we write hyperparameters to the CSV.
with all types of properties. For instance, a TB writer doesn't need a max step.
:param category: The category that the property belongs to.
:param type: The type of property.
:param value: The property itself.

return None
class CSVWriter(StatsWriter):
def __init__(self, base_dir: str, required_fields: List[str] = None):
"""
A StatsWriter that writes to a Tensorboard summary.
:param base_dir: The directory within which to place the CSV file, which will be {base_dir}/{category}.csv.
:param required_fields: If provided, the CSV writer won't write until these fields have statistics to write for
them.
"""
# We need to keep track of the fields in the CSV, as all rows need the same fields.
self.csv_fields: Dict[str, List[str]] = {}
self.required_fields = required_fields if required_fields else []
self.base_dir: str = base_dir
def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int
) -> None:
if self._maybe_create_csv_file(category, list(values.keys())):
row = [str(step)]
# Only record the stats that showed up in the first valid row
for key in self.csv_fields[category]:
_val = values.get(key, None)
row.append(str(_val.mean) if _val else "None")
with open(self._get_filepath(category), "a") as file:
writer = csv.writer(file)
writer.writerow(row)
def _maybe_create_csv_file(self, category: str, keys: List[str]) -> bool:
"""
If no CSV file exists and the keys have the required values,
make the CSV file and write hte title row.
Returns True if there is now (or already is) a valid CSV file.
"""
if category not in self.csv_fields:
summary_dir = self.base_dir
os.makedirs(summary_dir, exist_ok=True)
# Only store if the row contains the required fields
if all(item in keys for item in self.required_fields):
self.csv_fields[category] = keys
with open(self._get_filepath(category), "w") as file:
title_row = ["Steps"]
title_row.extend(keys)
writer = csv.writer(file)
writer.writerow(title_row)
return True
return False
return True
def _get_filepath(self, category: str) -> str:
file_dir = os.path.join(self.base_dir, category + ".csv")
return file_dir
class StatsReporter:
writers: List[StatsWriter] = []
stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list))

"""
Add a generic property to the StatsReporter. This could be e.g. a Dict of hyperparameters,
a max step count, a trainer type, etc. Note that not all StatsWriters need to be compatible
with all types of properties. For instance, a TB writer doesn't need a max step, nor should
we write hyperparameters to the CSV.
with all types of properties. For instance, a TB writer doesn't need a max step.
:param key: The type of property.
:param value: The property itself.
"""

42
ml-agents/mlagents/trainers/tests/test_stats.py


import pytest
import tempfile
import unittest
import csv
CSVWriter,
StatsSummary,
GaugeWriter,
ConsoleWriter,

tb_writer = TensorboardWriter(tmp_path, clear_past_data=True)
tb_writer.write_stats("category1", {"key1": statssummary1}, 10)
assert len(os.listdir(os.path.join(tmp_path, "category1"))) == 1
def test_csv_writer():
# Test write_stats
category = "category1"
with tempfile.TemporaryDirectory(prefix="unittest-") as base_dir:
csv_writer = CSVWriter(base_dir, required_fields=["key1", "key2"])
statssummary1 = StatsSummary(mean=1.0, std=1.0, num=1)
csv_writer.write_stats("category1", {"key1": statssummary1}, 10)
# Test that the filewriter has been created and the directory has been created.
filewriter_dir = "{basedir}/{category}.csv".format(
basedir=base_dir, category=category
)
# The required keys weren't in the stats
assert not os.path.exists(filewriter_dir)
csv_writer.write_stats(
"category1", {"key1": statssummary1, "key2": statssummary1}, 10
)
csv_writer.write_stats(
"category1", {"key1": statssummary1, "key2": statssummary1}, 20
)
# The required keys were in the stats
assert os.path.exists(filewriter_dir)
with open(filewriter_dir) as csv_file:
csv_reader = csv.reader(csv_file, delimiter=",")
line_count = 0
for row in csv_reader:
if line_count == 0:
assert "key1" in row
assert "key2" in row
assert "Steps" in row
line_count += 1
else:
assert len(row) == 3
line_count += 1
assert line_count == 3
def test_gauge_stat_writer_sanitize():

8
ml-agents/mlagents/trainers/trainer/rl_trainer.py


import abc
import time
import attr
from mlagents.model_serialization import SerializationSettings
from mlagents.model_serialization import SerializationSettings, copy_model_files
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpoint,
NNCheckpointManager,

"Trainer has multiple policies, but default behavior only saves the first."
)
policy = list(self.policies.values())[0]
settings = SerializationSettings(policy.model_path, self.brain_name)
# Copy the checkpointed model files to the final output location
copy_model_files(model_checkpoint.file_path, f"{policy.model_path}.nn")
policy.save(policy.model_path, settings)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
@abc.abstractmethod

8
com.unity.ml-agents/Runtime/Actuators.meta


fileFormatVersion: 2
guid: 26733e59183b6479e8f0e892a8bf09a4
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

8
com.unity.ml-agents/Tests/Editor/Actuators.meta


fileFormatVersion: 2
guid: c7e705f7d549e43c6be18ae809cd6f54
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

160
docs/images/TensorBoard-download.png

之前 之后
宽度: 709  |  高度: 393  |  大小: 46 KiB

181
com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs


using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// ActionSegment{T} is a data structure that allows access to a segment of an underlying array
/// in order to avoid the copying and allocation of sub-arrays. The segment is defined by
/// the offset into the original array, and an length.
/// </summary>
/// <typeparam name="T">The type of object stored in the underlying <see cref="Array"/></typeparam>
internal readonly struct ActionSegment<T> : IEnumerable<T>, IEquatable<ActionSegment<T>>
where T : struct
{
/// <summary>
/// The zero-based offset into the original array at which this segment starts.
/// </summary>
public readonly int Offset;
/// <summary>
/// The number of items this segment can access in the underlying array.
/// </summary>
public readonly int Length;
/// <summary>
/// An Empty segment which has an offset of 0, a Length of 0, and it's underlying array
/// is also empty.
/// </summary>
public static ActionSegment<T> Empty = new ActionSegment<T>(System.Array.Empty<T>(), 0, 0);
static void CheckParameters(T[] actionArray, int offset, int length)
{
#if DEBUG
if (offset + length > actionArray.Length)
{
throw new ArgumentOutOfRangeException(nameof(offset),
$"Arguments offset: {offset} and length: {length} " +
$"are out of bounds of actionArray: {actionArray.Length}.");
}
#endif
}
/// <summary>
/// Construct an <see cref="ActionSegment{T}"/> with an underlying array
/// and offset, and a length.
/// </summary>
/// <param name="actionArray">The underlying array which this segment has a view into</param>
/// <param name="offset">The zero-based offset into the underlying array.</param>
/// <param name="length">The length of the segment.</param>
public ActionSegment(T[] actionArray, int offset, int length)
{
CheckParameters(actionArray, offset, length);
Array = actionArray;
Offset = offset;
Length = length;
}
/// <summary>
/// Get the underlying <see cref="Array"/> of this segment.
/// </summary>
public T[] Array { get; }
/// <summary>
/// Allows access to the underlying array using array syntax.
/// </summary>
/// <param name="index">The zero-based index of the segment.</param>
/// <exception cref="IndexOutOfRangeException">Thrown when the index is less than 0 or
/// greater than or equal to <see cref="Length"/></exception>
public T this[int index]
{
get
{
if (index < 0 || index > Length)
{
throw new IndexOutOfRangeException($"Index out of bounds, expected a number between 0 and {Length}");
}
return Array[Offset + index];
}
}
/// <inheritdoc cref="IEnumerable{T}.GetEnumerator"/>
IEnumerator<T> IEnumerable<T>.GetEnumerator()
{
return new Enumerator(this);
}
/// <inheritdoc cref="IEnumerable{T}"/>
public IEnumerator GetEnumerator()
{
return new Enumerator(this);
}
/// <inheritdoc cref="ValueType.Equals(object)"/>
public override bool Equals(object obj)
{
if (!(obj is ActionSegment<T>))
{
return false;
}
return Equals((ActionSegment<T>)obj);
}
/// <inheritdoc cref="IEquatable{T}.Equals(T)"/>
public bool Equals(ActionSegment<T> other)
{
return Offset == other.Offset && Length == other.Length && Equals(Array, other.Array);
}
/// <inheritdoc cref="ValueType.GetHashCode"/>
public override int GetHashCode()
{
unchecked
{
var hashCode = Offset;
hashCode = (hashCode * 397) ^ Length;
hashCode = (hashCode * 397) ^ (Array != null ? Array.GetHashCode() : 0);
return hashCode;
}
}
/// <summary>
/// A private <see cref="IEnumerator{T}"/> for the <see cref="ActionSegment{T}"/> value type which follows its
/// rules of being a view into an underlying <see cref="Array"/>.
/// </summary>
struct Enumerator : IEnumerator<T>
{
readonly T[] m_Array;
readonly int m_Start;
readonly int m_End; // cache Offset + Count, since it's a little slow
int m_Current;
internal Enumerator(ActionSegment<T> arraySegment)
{
Debug.Assert(arraySegment.Array != null);
Debug.Assert(arraySegment.Offset >= 0);
Debug.Assert(arraySegment.Length >= 0);
Debug.Assert(arraySegment.Offset + arraySegment.Length <= arraySegment.Array.Length);
m_Array = arraySegment.Array;
m_Start = arraySegment.Offset;
m_End = arraySegment.Offset + arraySegment.Length;
m_Current = arraySegment.Offset - 1;
}
public bool MoveNext()
{
if (m_Current < m_End)
{
m_Current++;
return m_Current < m_End;
}
return false;
}
public T Current
{
get
{
if (m_Current < m_Start)
throw new InvalidOperationException("Enumerator not started.");
if (m_Current >= m_End)
throw new InvalidOperationException("Enumerator has reached the end already.");
return m_Array[m_Current];
}
}
object IEnumerator.Current => Current;
void IEnumerator.Reset()
{
m_Current = m_Start - 1;
}
public void Dispose()
{
}
}
}
}

3
com.unity.ml-agents/Runtime/Actuators/ActionSegment.cs.meta


fileFormatVersion: 2
guid: 4fa1432c1ba3460caaa84303a9011ef2
timeCreated: 1595869823

75
com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs


using System;
using System.Collections.Generic;
using System.Linq;
using Unity.MLAgents.Policies;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Defines the structure of an Action Space to be used by the Actuator system.
/// </summary>
internal readonly struct ActionSpec
{
/// <summary>
/// An array of branch sizes for our action space.
///
/// For an IActuator that uses a Discrete <see cref="SpaceType"/>, the number of
/// branches is the Length of the Array and each index contains the branch size.
/// The cumulative sum of the total number of discrete actions can be retrieved
/// by the <see cref="SumOfDiscreteBranchSizes"/> property.
///
/// For an IActuator with a Continuous it will be null.
/// </summary>
public readonly int[] BranchSizes;
/// <summary>
/// The number of actions for a Continuous <see cref="SpaceType"/>.
/// </summary>
public int NumContinuousActions { get; }
/// <summary>
/// The number of branches for a Discrete <see cref="SpaceType"/>.
/// </summary>
public int NumDiscreteActions { get; }
/// <summary>
/// Get the total number of Discrete Actions that can be taken by calculating the Sum
/// of all of the Discrete Action branch sizes.
/// </summary>
public int SumOfDiscreteBranchSizes { get; }
/// <summary>
/// Creates a Continuous <see cref="ActionSpec"/> with the number of actions available.
/// </summary>
/// <param name="numActions">The number of actions available.</param>
/// <returns>An Continuous ActionSpec initialized with the number of actions available.</returns>
public static ActionSpec MakeContinuous(int numActions)
{
var actuatorSpace = new ActionSpec(numActions, 0);
return actuatorSpace;
}
/// <summary>
/// Creates a Discrete <see cref="ActionSpec"/> with the array of branch sizes that
/// represents the action space.
/// </summary>
/// <param name="branchSizes">The array of branch sizes for the discrete action space. Each index
/// contains the number of actions available for that branch.</param>
/// <returns>An Discrete ActionSpec initialized with the array of branch sizes.</returns>
public static ActionSpec MakeDiscrete(int[] branchSizes)
{
var numActions = branchSizes.Length;
var actuatorSpace = new ActionSpec(0, numActions, branchSizes);
return actuatorSpace;
}
ActionSpec(int numContinuousActions, int numDiscreteActions, int[] branchSizes = null)
{
NumContinuousActions = numContinuousActions;
NumDiscreteActions = numDiscreteActions;
BranchSizes = branchSizes;
SumOfDiscreteBranchSizes = branchSizes?.Sum() ?? 0;
}
}
}

3
com.unity.ml-agents/Runtime/Actuators/ActionSpec.cs.meta


fileFormatVersion: 2
guid: ecdd6deefba1416ca149fe09d2a5afd8
timeCreated: 1595892361

17
com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs


using UnityEngine;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Editor components for creating Actuators. Generally an IActuator component should
/// have a corresponding ActuatorComponent.
/// </summary>
internal abstract class ActuatorComponent : MonoBehaviour
{
/// <summary>
/// Create the IActuator. This is called by the Agent when it is initialized.
/// </summary>
/// <returns>Created IActuator object.</returns>
public abstract IActuator CreateActuator();
}
}

3
com.unity.ml-agents/Runtime/Actuators/ActuatorComponent.cs.meta


fileFormatVersion: 2
guid: 77cefae5f6d841be9ff80b41293d271b
timeCreated: 1593017318

150
com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs


using System;
using System.Collections.Generic;
using System.Linq;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Implementation of IDiscreteActionMask that allows writing to the action mask from an <see cref="IActuator"/>.
/// </summary>
internal class ActuatorDiscreteActionMask : IDiscreteActionMask
{
/// When using discrete control, is the starting indices of the actions
/// when all the branches are concatenated with each other.
int[] m_StartingActionIndices;
int[] m_BranchSizes;
bool[] m_CurrentMask;
IList<IActuator> m_Actuators;
readonly int m_SumOfDiscreteBranchSizes;
readonly int m_NumBranches;
/// <summary>
/// The offset into the branches array that is used when actuators are writing to the action mask.
/// </summary>
public int CurrentBranchOffset { get; set; }
internal ActuatorDiscreteActionMask(IList<IActuator> actuators, int sumOfDiscreteBranchSizes, int numBranches)
{
m_Actuators = actuators;
m_SumOfDiscreteBranchSizes = sumOfDiscreteBranchSizes;
m_NumBranches = numBranches;
}
/// <inheritdoc cref="IDiscreteActionMask.WriteMask"/>
public void WriteMask(int branch, IEnumerable<int> actionIndices)
{
LazyInitialize();
// Perform the masking
foreach (var actionIndex in actionIndices)
{
#if DEBUG
if (branch >= m_NumBranches || actionIndex >= m_BranchSizes[CurrentBranchOffset + branch])
{
throw new UnityAgentsException(
"Invalid Action Masking: Action Mask is too large for specified branch.");
}
#endif
m_CurrentMask[actionIndex + m_StartingActionIndices[CurrentBranchOffset + branch]] = true;
}
}
void LazyInitialize()
{
if (m_BranchSizes == null)
{
m_BranchSizes = new int[m_NumBranches];
var start = 0;
for (var i = 0; i < m_Actuators.Count; i++)
{
var actuator = m_Actuators[i];
var branchSizes = actuator.ActionSpec.BranchSizes;
Array.Copy(branchSizes, 0, m_BranchSizes, start, branchSizes.Length);
start += branchSizes.Length;
}
}
// By default, the masks are null. If we want to specify a new mask, we initialize
// the actionMasks with trues.
if (m_CurrentMask == null)
{
m_CurrentMask = new bool[m_SumOfDiscreteBranchSizes];
}
// If this is the first time the masked actions are used, we generate the starting
// indices for each branch.
if (m_StartingActionIndices == null)
{
m_StartingActionIndices = Utilities.CumSum(m_BranchSizes);
}
}
/// <inheritdoc cref="IDiscreteActionMask.GetMask"/>
public bool[] GetMask()
{
#if DEBUG
if (m_CurrentMask != null)
{
AssertMask();
}
#endif
return m_CurrentMask;
}
/// <summary>
/// Makes sure that the current mask is usable.
/// </summary>
void AssertMask()
{
#if DEBUG
for (var branchIndex = 0; branchIndex < m_NumBranches; branchIndex++)
{
if (AreAllActionsMasked(branchIndex))
{
throw new UnityAgentsException(
"Invalid Action Masking : All the actions of branch " + branchIndex +
" are masked.");
}
}
#endif
}
/// <summary>
/// Resets the current mask for an agent.
/// </summary>
public void ResetMask()
{
if (m_CurrentMask != null)
{
Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
}
}
/// <summary>
/// Checks if all the actions in the input branch are masked.
/// </summary>
/// <param name="branch"> The index of the branch to check.</param>
/// <returns> True if all the actions of the branch are masked.</returns>
bool AreAllActionsMasked(int branch)
{
if (m_CurrentMask == null)
{
return false;
}
var start = m_StartingActionIndices[branch];
var end = m_StartingActionIndices[branch + 1];
for (var i = start; i < end; i++)
{
if (!m_CurrentMask[i])
{
return false;
}
}
return true;
}
}
}

3
com.unity.ml-agents/Runtime/Actuators/ActuatorDiscreteActionMask.cs.meta


fileFormatVersion: 2
guid: d2a19e2f43fd4637a38d42b2a5f989f3
timeCreated: 1595459316

415
com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs


using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// A class that manages the delegation of events, action buffers, and action mask for a list of IActuators.
/// </summary>
internal class ActuatorManager : IList<IActuator>
{
// IActuators managed by this object.
IList<IActuator> m_Actuators;
// An implementation of IDiscreteActionMask that allows for writing to it based on an offset.
ActuatorDiscreteActionMask m_DiscreteActionMask;
/// <summary>
/// Flag used to check if our IActuators are ready for execution.
/// </summary>
/// <seealso cref="ReadyActuatorsForExecution(IList{IActuator}, int, int, int)"/>
bool m_ReadyForExecution;
/// <summary>
/// The sum of all of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
/// </summary>
internal int SumOfDiscreteBranchSizes { get; private set; }
/// <summary>
/// The number of the discrete branches for all of the <see cref="IActuator"/>s in this manager.
/// </summary>
internal int NumDiscreteActions { get; private set; }
/// <summary>
/// The number of continuous actions for all of the <see cref="IActuator"/>s in this manager.
/// </summary>
internal int NumContinuousActions { get; private set; }
/// <summary>
/// Returns the total actions which is calculated by <see cref="NumContinuousActions"/> + <see cref="NumDiscreteActions"/>.
/// </summary>
public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions;
/// <summary>
/// Gets the <see cref="IDiscreteActionMask"/> managed by this object.
/// </summary>
public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask;
/// <summary>
/// Returns the previously stored actions for the actuators in this list.
/// </summary>
public float[] StoredContinuousActions { get; private set; }
/// <summary>
/// Returns the previously stored actions for the actuators in this list.
/// </summary>
public int[] StoredDiscreteActions { get; private set; }
/// <summary>
/// Create an ActuatorList with a preset capacity.
/// </summary>
/// <param name="capacity">The capacity of the list to create.</param>
public ActuatorManager(int capacity = 0)
{
m_Actuators = new List<IActuator>(capacity);
}
/// <summary>
/// <see cref="ReadyActuatorsForExecution(IList{IActuator}, int, int, int)"/>
/// </summary>
void ReadyActuatorsForExecution()
{
ReadyActuatorsForExecution(m_Actuators, NumContinuousActions, SumOfDiscreteBranchSizes,
NumDiscreteActions);
}
/// <summary>
/// This method validates that all <see cref="IActuator"/>s have unique names and equivalent action space types
/// if the `DEBUG` preprocessor macro is defined, and allocates the appropriate buffers to manage the actions for
/// all of the <see cref="IActuator"/>s that may live on a particular object.
/// </summary>
/// <param name="actuators">The list of actuators to validate and allocate buffers for.</param>
/// <param name="numContinuousActions">The total number of continuous actions for all of the actuators.</param>
/// <param name="sumOfDiscreteBranches">The total sum of the discrete branches for all of the actuators in order
/// to be able to allocate an <see cref="IDiscreteActionMask"/>.</param>
/// <param name="numDiscreteBranches">The number of discrete branches for all of the actuators.</param>
internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numContinuousActions, int sumOfDiscreteBranches, int numDiscreteBranches)
{
if (m_ReadyForExecution)
{
return;
}
#if DEBUG
// Make sure the names are actually unique
// Make sure all Actuators have the same SpaceType
ValidateActuators();
#endif
// Sort the Actuators by name to ensure determinism
SortActuators();
StoredContinuousActions = numContinuousActions == 0 ? Array.Empty<float>() : new float[numContinuousActions];
StoredDiscreteActions = numDiscreteBranches == 0 ? Array.Empty<int>() : new int[numDiscreteBranches];
m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches);
m_ReadyForExecution = true;
}
/// <summary>
/// Updates the local action buffer with the action buffer passed in. If the buffer
/// passed in is null, the local action buffer will be cleared.
/// </summary>
/// <param name="continuousActionBuffer">The action buffer which contains all of the
/// continuous actions for the IActuators in this list.</param>
/// <param name="discreteActionBuffer">The action buffer which contains all of the
/// discrete actions for the IActuators in this list.</param>
public void UpdateActions(float[] continuousActionBuffer, int[] discreteActionBuffer)
{
ReadyActuatorsForExecution();
UpdateActionArray(continuousActionBuffer, StoredContinuousActions);
UpdateActionArray(discreteActionBuffer, StoredDiscreteActions);
}
static void UpdateActionArray<T>(T[] sourceActionBuffer, T[] destination)
{
if (sourceActionBuffer == null || sourceActionBuffer.Length == 0)
{
Array.Clear(destination, 0, destination.Length);
}
else
{
Debug.Assert(sourceActionBuffer.Length == destination.Length,
$"sourceActionBuffer:{sourceActionBuffer.Length} is a different" +
$" size than destination: {destination.Length}.");
Array.Copy(sourceActionBuffer, destination, destination.Length);
}
}
/// <summary>
/// This method will trigger the writing to the <see cref="IDiscreteActionMask"/> by all of the actuators
/// managed by this object.
/// </summary>
public void WriteActionMask()
{
ReadyActuatorsForExecution();
m_DiscreteActionMask.ResetMask();
var offset = 0;
for (var i = 0; i < m_Actuators.Count; i++)
{
var actuator = m_Actuators[i];
m_DiscreteActionMask.CurrentBranchOffset = offset;
actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
offset += actuator.ActionSpec.NumDiscreteActions;
}
}
/// <summary>
/// Iterates through all of the IActuators in this list and calls their
/// <see cref="IActionReceiver.OnActionReceived"/> method on them with the appropriate
/// <see cref="ActionSegment{T}"/>s depending on their <see cref="IActionReceiver.ActionSpec"/>.
/// </summary>
public void ExecuteActions()
{
ReadyActuatorsForExecution();
var continuousStart = 0;
var discreteStart = 0;
for (var i = 0; i < m_Actuators.Count; i++)
{
var actuator = m_Actuators[i];
var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;
var continuousActions = ActionSegment<float>.Empty;
if (numContinuousActions > 0)
{
continuousActions = new ActionSegment<float>(StoredContinuousActions,
continuousStart,
numContinuousActions);
}
var discreteActions = ActionSegment<int>.Empty;
if (numDiscreteActions > 0)
{
discreteActions = new ActionSegment<int>(StoredDiscreteActions,
discreteStart,
numDiscreteActions);
}
actuator.OnActionReceived(new ActionBuffers(continuousActions, discreteActions));
continuousStart += numContinuousActions;
discreteStart += numDiscreteActions;
}
}
/// <summary>
/// Resets the <see cref="StoredContinuousActions"/> and <see cref="StoredDiscreteActions"/> buffers to be all
/// zeros and calls <see cref="IActuator.ResetData"/> on each <see cref="IActuator"/> managed by this object.
/// </summary>
public void ResetData()
{
if (!m_ReadyForExecution)
{
return;
}
Array.Clear(StoredContinuousActions, 0, StoredContinuousActions.Length);
Array.Clear(StoredDiscreteActions, 0, StoredDiscreteActions.Length);
for (var i = 0; i < m_Actuators.Count; i++)
{
m_Actuators[i].ResetData();
}
}
/// <summary>
/// Sorts the <see cref="IActuator"/>s according to their <see cref="IActuator.GetName"/> value.
/// </summary>
void SortActuators()
{
((List<IActuator>)m_Actuators).Sort((x,
y) => x.Name
.CompareTo(y.Name));
}
/// <summary>
/// Validates that the IActuators managed by this object have unique names and equivalent action space types.
/// Each Actuator needs to have a unique name in order for this object to ensure that the storage of action
/// buffers, and execution of Actuators remains deterministic across different sessions of running.
/// </summary>
void ValidateActuators()
{
for (var i = 0; i < m_Actuators.Count - 1; i++)
{
Debug.Assert(
!m_Actuators[i].Name.Equals(m_Actuators[i + 1].Name),
"Actuator names must be unique.");
var first = m_Actuators[i].ActionSpec;
var second = m_Actuators[i + 1].ActionSpec;
Debug.Assert(first.NumContinuousActions > 0 == second.NumContinuousActions > 0,
"Actuators on the same Agent must have the same action SpaceType.");
}
}
/// <summary>
/// Helper method to update bookkeeping items around buffer management for actuators added to this object.
/// </summary>
/// <param name="actuatorItem">The IActuator to keep bookkeeping for.</param>
void AddToBufferSizes(IActuator actuatorItem)
{
if (actuatorItem == null)
{
return;
}
NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions;
NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions;
SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
}
/// <summary>
/// Helper method to update bookkeeping items around buffer management for actuators removed from this object.
/// </summary>
/// <param name="actuatorItem">The IActuator to keep bookkeeping for.</param>
void SubtractFromBufferSize(IActuator actuatorItem)
{
if (actuatorItem == null)
{
return;
}
NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions;
NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions;
SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
}
/// <summary>
/// Sets all of the bookkeeping items back to 0.
/// </summary>
void ClearBufferSizes()
{
NumContinuousActions = NumDiscreteActions = SumOfDiscreteBranchSizes = 0;
}
/*********************************************************************************
* IList implementation that delegates to m_Actuators List. *
*********************************************************************************/
/// <summary>
/// <inheritdoc cref="IEnumerable{T}.GetEnumerator"/>
/// </summary>
public IEnumerator<IActuator> GetEnumerator()
{
return m_Actuators.GetEnumerator();
}
/// <summary>
/// <inheritdoc cref="IList{T}.GetEnumerator"/>
/// </summary>
IEnumerator IEnumerable.GetEnumerator()
{
return ((IEnumerable)m_Actuators).GetEnumerator();
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.Add"/>
/// </summary>
/// <param name="item"></param>
public void Add(IActuator item)
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot add to the ActuatorManager after its buffers have been initialized");
m_Actuators.Add(item);
AddToBufferSizes(item);
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.Clear"/>
/// </summary>
public void Clear()
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot clear the ActuatorManager after its buffers have been initialized");
m_Actuators.Clear();
ClearBufferSizes();
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.Contains"/>
/// </summary>
public bool Contains(IActuator item)
{
return m_Actuators.Contains(item);
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.CopyTo"/>
/// </summary>
public void CopyTo(IActuator[] array, int arrayIndex)
{
m_Actuators.CopyTo(array, arrayIndex);
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.Remove"/>
/// </summary>
public bool Remove(IActuator item)
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot remove from the ActuatorManager after its buffers have been initialized");
if (m_Actuators.Remove(item))
{
SubtractFromBufferSize(item);
return true;
}
return false;
}
/// <summary>
/// <inheritdoc cref="ICollection{T}.Count"/>
/// </summary>
public int Count => m_Actuators.Count;
/// <summary>
/// <inheritdoc cref="ICollection{T}.IsReadOnly"/>
/// </summary>
public bool IsReadOnly => m_Actuators.IsReadOnly;
/// <summary>
/// <inheritdoc cref="IList{T}.IndexOf"/>
/// </summary>
public int IndexOf(IActuator item)
{
return m_Actuators.IndexOf(item);
}
/// <summary>
/// <inheritdoc cref="IList{T}.Insert"/>
/// </summary>
public void Insert(int index, IActuator item)
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot insert into the ActuatorManager after its buffers have been initialized");
m_Actuators.Insert(index, item);
AddToBufferSizes(item);
}
/// <summary>
/// <inheritdoc cref="IList{T}.RemoveAt"/>
/// </summary>
public void RemoveAt(int index)
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot remove from the ActuatorManager after its buffers have been initialized");
var actuator = m_Actuators[index];
SubtractFromBufferSize(actuator);
m_Actuators.RemoveAt(index);
}
/// <summary>
/// <inheritdoc cref="IList{T}.this"/>
/// </summary>
public IActuator this[int index]
{
get => m_Actuators[index];
set
{
Debug.Assert(m_ReadyForExecution == false,
"Cannot modify the ActuatorManager after its buffers have been initialized");
var old = m_Actuators[index];
SubtractFromBufferSize(old);
m_Actuators[index] = value;
AddToBufferSizes(value);
}
}
}
}

3
com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs.meta


fileFormatVersion: 2
guid: 7bb5b1e3779d4342a8e70f6e3c1d67cc
timeCreated: 1593031463

101
com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs


using System;
using System.Linq;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// A structure that wraps the <see cref="ActionSegment{T}"/>s for a particular <see cref="IActionReceiver"/> and is
/// used when <see cref="IActionReceiver.OnActionReceived"/> is called.
/// </summary>
internal readonly struct ActionBuffers
{
/// <summary>
/// An empty action buffer.
/// </summary>
public static ActionBuffers Empty = new ActionBuffers(ActionSegment<float>.Empty, ActionSegment<int>.Empty);
/// <summary>
/// Holds the Continuous <see cref="ActionSegment{T}"/> to be used by an <see cref="IActionReceiver"/>.
/// </summary>
public ActionSegment<float> ContinuousActions { get; }
/// <summary>
/// Holds the Discrete <see cref="ActionSegment{T}"/> to be used by an <see cref="IActionReceiver"/>.
/// </summary>
public ActionSegment<int> DiscreteActions { get; }
/// <summary>
/// Construct an <see cref="ActionBuffers"/> instance with the continuous and discrete actions that will
/// be used.
/// </summary>
/// <param name="continuousActions">The continuous actions to send to an <see cref="IActionReceiver"/>.</param>
/// <param name="discreteActions">The discrete actions to send to an <see cref="IActionReceiver"/>.</param>
public ActionBuffers(ActionSegment<float> continuousActions, ActionSegment<int> discreteActions)
{
ContinuousActions = continuousActions;
DiscreteActions = discreteActions;
}
/// <inheritdoc cref="ValueType.Equals(object)"/>
public override bool Equals(object obj)
{
if (!(obj is ActionBuffers))
{
return false;
}
var ab = (ActionBuffers)obj;
return ab.ContinuousActions.SequenceEqual(ContinuousActions) &&
ab.DiscreteActions.SequenceEqual(DiscreteActions);
}
/// <inheritdoc cref="ValueType.GetHashCode"/>
public override int GetHashCode()
{
unchecked
{
return (ContinuousActions.GetHashCode() * 397) ^ DiscreteActions.GetHashCode();
}
}
}
/// <summary>
/// An interface that describes an object that can receive actions from a Reinforcement Learning network.
/// </summary>
internal interface IActionReceiver
{
/// <summary>
/// The specification of the Action space for this IActionReceiver.
/// </summary>
/// <seealso cref="ActionSpec"/>
ActionSpec ActionSpec { get; }
/// <summary>
/// Method called in order too allow object to execute actions based on the
/// <see cref="ActionBuffers"/> contents. The structure of the contents in the <see cref="ActionBuffers"/>
/// are defined by the <see cref="ActionSpec"/>.
/// </summary>
/// <param name="actionBuffers">The data structure containing the action buffers for this object.</param>
void OnActionReceived(ActionBuffers actionBuffers);
/// <summary>
/// Implement `WriteDiscreteActionMask()` to modify the masks for discrete
/// actions. When using discrete actions, the agent will not perform the masked
/// action.
/// </summary>
/// <param name="actionMask">
/// The action mask for the agent.
/// </param>
/// <remarks>
/// When using Discrete Control, you can prevent the Agent from using a certain
/// action by masking it with <see cref="IDiscreteActionMask.WriteMask"/>.
///
/// See [Agents - Actions] for more information on masking actions.
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_4_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <seealso cref="IActionReceiver.OnActionReceived"/>
void WriteDiscreteActionMask(IDiscreteActionMask actionMask);
}
}

3
com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs.meta


fileFormatVersion: 2
guid: b25a5b3027c9476ea1a310241be0f10f
timeCreated: 1594756775

21
com.unity.ml-agents/Runtime/Actuators/IActuator.cs


using System;
using UnityEngine;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Abstraction that facilitates the execution of actions.
/// </summary>
internal interface IActuator : IActionReceiver
{
int TotalNumberOfActions { get; }
/// <summary>
/// Gets the name of this IActuator which will be used to sort it.
/// </summary>
/// <returns></returns>
string Name { get; }
void ResetData();
}
}

3
com.unity.ml-agents/Runtime/Actuators/IActuator.cs.meta


fileFormatVersion: 2
guid: 780d7f0a675f44bfa784b370025b51c3
timeCreated: 1592848317

38
com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs


using System.Collections.Generic;
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Interface for writing a mask to disable discrete actions for agents for the next decision.
/// </summary>
internal interface IDiscreteActionMask
{
/// <summary>
/// Modifies an action mask for discrete control agents.
/// </summary>
/// <remarks>
/// When used, the agent will not be able to perform the actions passed as argument
/// at the next decision for the specified action branch. The actionIndices correspond
/// to the action options the agent will be unable to perform.
///
/// See [Agents - Actions] for more information on masking actions.
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_2_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <param name="branch">The branch for which the actions will be masked.</param>
/// <param name="actionIndices">The indices of the masked actions.</param>
void WriteMask(int branch, IEnumerable<int> actionIndices);
/// <summary>
/// Get the current mask for an agent.
/// </summary>
/// <returns>A mask for the agent. A boolean array of length equal to the total number of
/// actions.</returns>
bool[] GetMask();
/// <summary>
/// Resets the current mask for an agent.
/// </summary>
void ResetMask();
}
}

3
com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs.meta


fileFormatVersion: 2
guid: 1bc4e4b71bf4470789488fab2ee65388
timeCreated: 1595369065

72
com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs


using System;
using Unity.MLAgents.Policies;
namespace Unity.MLAgents.Actuators
{
internal class VectorActuator : IActuator
{
IActionReceiver m_ActionReceiver;
ActionBuffers m_ActionBuffers;
internal ActionBuffers ActionBuffers
{
get => m_ActionBuffers;
private set => m_ActionBuffers = value;
}
public VectorActuator(IActionReceiver actionReceiver,
int[] vectorActionSize,
SpaceType spaceType,
string name = "VectorActuator")
{
m_ActionReceiver = actionReceiver;
string suffix;
switch (spaceType)
{
case SpaceType.Continuous:
ActionSpec = ActionSpec.MakeContinuous(vectorActionSize[0]);
suffix = "-Continuous";
break;
case SpaceType.Discrete:
ActionSpec = ActionSpec.MakeDiscrete(vectorActionSize);
suffix = "-Discrete";
break;
default:
throw new ArgumentOutOfRangeException(nameof(spaceType),
spaceType,
"Unknown enum value.");
}
Name = name + suffix;
}
public void ResetData()
{
m_ActionBuffers = ActionBuffers.Empty;
}
public void OnActionReceived(ActionBuffers actionBuffers)
{
ActionBuffers = actionBuffers;
m_ActionReceiver.OnActionReceived(ActionBuffers);
}
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
m_ActionReceiver.WriteDiscreteActionMask(actionMask);
}
/// <summary>
/// Returns the number of discrete branches + the number of continuous actions.
/// </summary>
public int TotalNumberOfActions => ActionSpec.NumContinuousActions +
ActionSpec.NumDiscreteActions;
/// <summary>
/// <inheritdoc cref="IActionReceiver.ActionSpec"/>
/// </summary>
public ActionSpec ActionSpec { get; }
public string Name { get; }
}
}

3
com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs.meta


fileFormatVersion: 2
guid: ff7a3292c0b24b23b3f1c0eeb690ec4c
timeCreated: 1593023833

55
com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs


using System;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
namespace Unity.MLAgents.Tests.Actuators
{
[TestFixture]
public class ActionSegmentTests
{
[Test]
public void TestConstruction()
{
var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
Assert.Throws<ArgumentOutOfRangeException>(
() => new ActionSegment<float>(floatArray, 100, 1));
var segment = new ActionSegment<float>(Array.Empty<float>(), 0, 0);
Assert.AreEqual(segment, ActionSegment<float>.Empty);
}
[Test]
public void TestIndexing()
{
var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
for (var i = 0; i < floatArray.Length; i++)
{
var start = 0 + i;
var length = floatArray.Length - i;
var actionSegment = new ActionSegment<float>(floatArray, start, length);
for (var j = 0; j < actionSegment.Length; j++)
{
Assert.AreEqual(actionSegment[j], floatArray[start + j]);
}
}
}
[Test]
public void TestEnumerator()
{
var floatArray = new[] { 1f, 2f, 3f, 4f, 5f, 6f, 7f };
for (var i = 0; i < floatArray.Length; i++)
{
var start = 0 + i;
var length = floatArray.Length - i;
var actionSegment = new ActionSegment<float>(floatArray, start, length);
var j = 0;
foreach (var item in actionSegment)
{
Assert.AreEqual(item, floatArray[start + j++]);
}
}
}
}
}

3
com.unity.ml-agents/Tests/Editor/Actuators/ActionSegmentTests.cs.meta


fileFormatVersion: 2
guid: 18cb6d052fba43a2b7437d87c0d9abad
timeCreated: 1596486604

114
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs


using System;
using System.Collections.Generic;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
namespace Unity.MLAgents.Tests.Actuators
{
[TestFixture]
public class ActuatorDiscreteActionMaskTests
{
[Test]
public void Construction()
{
var masker = new ActuatorDiscreteActionMask(new List<IActuator>(), 0, 0);
Assert.IsNotNull(masker);
}
[Test]
public void NullMask()
{
var masker = new ActuatorDiscreteActionMask(new List<IActuator>(), 0, 0);
var mask = masker.GetMask();
Assert.IsNull(mask);
}
[Test]
public void FirstBranchMask()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
var mask = masker.GetMask();
Assert.IsNull(mask);
masker.WriteMask(0, new[] {1, 2, 3});
mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsTrue(mask[1]);
Assert.IsTrue(mask[2]);
Assert.IsTrue(mask[3]);
Assert.IsFalse(mask[4]);
Assert.AreEqual(mask.Length, 15);
}
[Test]
public void SecondBranchMask()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
var masker = new ActuatorDiscreteActionMask(new[] {actuator1}, 15, 3);
masker.WriteMask(1, new[] {1, 2, 3});
var mask = masker.GetMask();
Assert.IsFalse(mask[0]);
Assert.IsFalse(mask[4]);
Assert.IsTrue(mask[5]);
Assert.IsTrue(mask[6]);
Assert.IsTrue(mask[7]);
Assert.IsFalse(mask[8]);
Assert.IsFalse(mask[9]);
}
[Test]
public void MaskReset()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
masker.WriteMask(1, new[] {1, 2, 3});
masker.ResetMask();
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)
{
Assert.IsFalse(mask[i]);
}
}
[Test]
public void ThrowsError()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
Assert.Catch<UnityAgentsException>(
() => masker.WriteMask(0, new[] {5}));
Assert.Catch<UnityAgentsException>(
() => masker.WriteMask(1, new[] {5}));
masker.WriteMask(2, new[] {5});
Assert.Catch<UnityAgentsException>(
() => masker.WriteMask(3, new[] {1}));
masker.GetMask();
masker.ResetMask();
masker.WriteMask(0, new[] {0, 1, 2, 3});
Assert.Catch<UnityAgentsException>(
() => masker.GetMask());
}
[Test]
public void MultipleMaskEdit()
{
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {4, 5, 6}), "actuator1");
var masker = new ActuatorDiscreteActionMask(new IActuator[] {actuator1}, 15, 3);
masker.WriteMask(0, new[] {0, 1});
masker.WriteMask(0, new[] {3});
masker.WriteMask(2, new[] {1});
var mask = masker.GetMask();
for (var i = 0; i < 15; i++)
{
if ((i == 0) || (i == 1) || (i == 3) || (i == 10))
{
Assert.IsTrue(mask[i]);
}
else
{
Assert.IsFalse(mask[i]);
}
}
}
}
}

3
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorDiscreteActionMaskTests.cs.meta


fileFormatVersion: 2
guid: b9f5f87049d04d8bba39d193a3ab2f5a
timeCreated: 1596491682

310
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using UnityEngine;
using UnityEngine.TestTools;
using Assert = UnityEngine.Assertions.Assert;
namespace Unity.MLAgents.Tests.Actuators
{
[TestFixture]
public class ActuatorManagerTests
{
[Test]
public void TestEnsureBufferSizeContinuous()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(10), "actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(2), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var actuator1ActionSpaceDef = actuator1.ActionSpec;
var actuator2ActionSpaceDef = actuator2.ActionSpec;
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 },
actuator1ActionSpaceDef.NumContinuousActions + actuator2ActionSpaceDef.NumContinuousActions,
actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes,
actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions);
manager.UpdateActions(new[]
{ 0f, 1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f }, Array.Empty<int>());
Assert.IsTrue(12 == manager.NumContinuousActions);
Assert.IsTrue(0 == manager.NumDiscreteActions);
Assert.IsTrue(0 == manager.SumOfDiscreteBranchSizes);
Assert.IsTrue(12 == manager.StoredContinuousActions.Length);
Assert.IsTrue(0 == manager.StoredDiscreteActions.Length);
}
[Test]
public void TestEnsureBufferDiscrete()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 1, 1}), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var actuator1ActionSpaceDef = actuator1.ActionSpec;
var actuator2ActionSpaceDef = actuator2.ActionSpec;
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 },
actuator1ActionSpaceDef.NumContinuousActions + actuator2ActionSpaceDef.NumContinuousActions,
actuator1ActionSpaceDef.SumOfDiscreteBranchSizes + actuator2ActionSpaceDef.SumOfDiscreteBranchSizes,
actuator1ActionSpaceDef.NumDiscreteActions + actuator2ActionSpaceDef.NumDiscreteActions);
manager.UpdateActions(Array.Empty<float>(),
new[] { 0, 1, 2, 3, 4, 5, 6});
Assert.IsTrue(0 == manager.NumContinuousActions);
Assert.IsTrue(7 == manager.NumDiscreteActions);
Assert.IsTrue(13 == manager.SumOfDiscreteBranchSizes);
Assert.IsTrue(0 == manager.StoredContinuousActions.Length);
Assert.IsTrue(7 == manager.StoredDiscreteActions.Length);
}
[Test]
public void TestFailOnMixedActionSpace()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, 3, 10, 4);
LogAssert.Expect(LogType.Assert, "Actuators on the same Agent must have the same action SpaceType.");
}
[Test]
public void TestFailOnSameActuatorName()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator1");
manager.Add(actuator1);
manager.Add(actuator2);
manager.ReadyActuatorsForExecution(new[] { actuator1, actuator2 }, 3, 10, 4);
LogAssert.Expect(LogType.Assert, "Actuator names must be unique.");
}
[Test]
public void TestExecuteActionsDiscrete()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1 , 2, 3, 4}), "actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 1, 1}), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5, 6};
manager.UpdateActions(Array.Empty<float>(),
discreteActionBuffer);
manager.ExecuteActions();
var actuator1Actions = actuator1.LastActionBuffer.DiscreteActions;
var actuator2Actions = actuator2.LastActionBuffer.DiscreteActions;
TestSegmentEquality(actuator1Actions, discreteActionBuffer); TestSegmentEquality(actuator2Actions, discreteActionBuffer);
}
[Test]
public void TestExecuteActionsContinuous()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
manager.UpdateActions(continuousActionBuffer,
Array.Empty<int>());
manager.ExecuteActions();
var actuator1Actions = actuator1.LastActionBuffer.ContinuousActions;
var actuator2Actions = actuator2.LastActionBuffer.ContinuousActions;
TestSegmentEquality(actuator1Actions, continuousActionBuffer);
TestSegmentEquality(actuator2Actions, continuousActionBuffer);
}
static void TestSegmentEquality<T>(ActionSegment<T> actionSegment, T[] actionBuffer)
where T : struct
{
Assert.IsFalse(actionSegment.Length == 0);
for (var i = 0; i < actionSegment.Length; i++)
{
var action = actionSegment[i];
Assert.AreEqual(action, actionBuffer[actionSegment.Offset + i]);
}
}
[Test]
public void TestUpdateActionsContinuous()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
manager.UpdateActions(continuousActionBuffer,
Array.Empty<int>());
Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer));
}
[Test]
public void TestUpdateActionsDiscrete()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var discreteActionBuffer = new[] { 0, 1, 2, 3, 4, 5};
manager.UpdateActions(Array.Empty<float>(),
discreteActionBuffer);
Debug.Log(manager.StoredDiscreteActions);
Debug.Log(discreteActionBuffer);
Assert.IsTrue(manager.StoredDiscreteActions.SequenceEqual(discreteActionBuffer));
}
[Test]
public void TestRemove()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
Assert.IsTrue(manager.NumDiscreteActions == 6);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 12);
manager.Remove(actuator2);
Assert.IsTrue(manager.NumDiscreteActions == 3);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6);
manager.Remove(null);
Assert.IsTrue(manager.NumDiscreteActions == 3);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6);
manager.RemoveAt(0);
Assert.IsTrue(manager.NumDiscreteActions == 0);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 0);
}
[Test]
public void TestClear()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3 }),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
Assert.IsTrue(manager.NumDiscreteActions == 6);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 12);
manager.Clear();
Assert.IsTrue(manager.NumDiscreteActions == 0);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 0);
}
[Test]
public void TestIndexSet()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4}),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
manager.Add(actuator1);
Assert.IsTrue(manager.NumDiscreteActions == 4);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 10);
manager[0] = actuator2;
Assert.IsTrue(manager.NumDiscreteActions == 3);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 6);
}
[Test]
public void TestInsert()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeDiscrete(new[] { 1, 2, 3, 4}),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "actuator2");
manager.Add(actuator1);
Assert.IsTrue(manager.NumDiscreteActions == 4);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 10);
manager.Insert(0, actuator2);
Assert.IsTrue(manager.NumDiscreteActions == 7);
Assert.IsTrue(manager.SumOfDiscreteBranchSizes == 16);
Assert.IsTrue(manager.IndexOf(actuator2) == 0);
}
[Test]
public void TestResetData()
{
var manager = new ActuatorManager();
var actuator1 = new TestActuator(ActionSpec.MakeContinuous(3),
"actuator1");
var actuator2 = new TestActuator(ActionSpec.MakeContinuous(3), "actuator2");
manager.Add(actuator1);
manager.Add(actuator2);
var continuousActionBuffer = new[] { 0f, 1f, 2f, 3f, 4f, 5f};
manager.UpdateActions(continuousActionBuffer,
Array.Empty<int>());
Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(continuousActionBuffer));
Assert.IsTrue(manager.NumContinuousActions == 6);
manager.ResetData();
Assert.IsTrue(manager.StoredContinuousActions.SequenceEqual(new[] { 0f, 0f, 0f, 0f, 0f, 0f}));
}
[Test]
public void TestWriteDiscreteActionMask()
{
var manager = new ActuatorManager(2);
var va1 = new TestActuator(ActionSpec.MakeDiscrete(new[] {1, 2, 3}), "name");
var va2 = new TestActuator(ActionSpec.MakeDiscrete(new[] {3, 2, 1}), "name1");
manager.Add(va1);
manager.Add(va2);
var groundTruthMask = new[]
{
false,
true, false,
false, true, true,
true, false, true,
false, true,
false
};
va1.Masks = new[]
{
Array.Empty<int>(),
new[] { 0 },
new[] { 1, 2 }
};
va2.Masks = new[]
{
new[] {0, 2},
new[] {1},
Array.Empty<int>()
};
manager.WriteActionMask();
Assert.IsTrue(groundTruthMask.SequenceEqual(manager.DiscreteActionMask.GetMask()));
}
}
}

3
com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs.meta


fileFormatVersion: 2
guid: d48ba72f0ac64d7db0af22c9d82b11d8
timeCreated: 1596494279

38
com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs


using Unity.MLAgents.Actuators;
namespace Unity.MLAgents.Tests.Actuators
{
internal class TestActuator : IActuator
{
public ActionBuffers LastActionBuffer;
public int[][] Masks;
public TestActuator(ActionSpec actuatorSpace, string name)
{
ActionSpec = actuatorSpace;
TotalNumberOfActions = actuatorSpace.NumContinuousActions +
actuatorSpace.NumDiscreteActions;
Name = name;
}
public void OnActionReceived(ActionBuffers actionBuffers)
{
LastActionBuffer = actionBuffers;
}
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
for (var i = 0; i < Masks.Length; i++)
{
actionMask.WriteMask(i, Masks[i]);
}
}
public int TotalNumberOfActions { get; }
public ActionSpec ActionSpec { get; }
public string Name { get; }
public void ResetData()
{
}
}
}

3
com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs.meta


fileFormatVersion: 2
guid: fa950d7b175749bfa287fd8761dd831f
timeCreated: 1596665978

98
com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs


using System.Collections.Generic;
using System.Linq;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using Assert = UnityEngine.Assertions.Assert;
namespace Unity.MLAgents.Tests.Actuators
{
[TestFixture]
public class VectorActuatorTests
{
class TestActionReceiver : IActionReceiver
{
public ActionBuffers LastActionBuffers;
public int Branch;
public IList<int> Mask;
public ActionSpec ActionSpec { get; }
public void OnActionReceived(ActionBuffers actionBuffers)
{
LastActionBuffers = actionBuffers;
}
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
actionMask.WriteMask(Branch, Mask);
}
}
[Test]
public void TestConstruct()
{
var ar = new TestActionReceiver();
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
Assert.IsTrue(va.ActionSpec.NumDiscreteActions == 3);
Assert.IsTrue(va.ActionSpec.SumOfDiscreteBranchSizes == 6);
Assert.IsTrue(va.ActionSpec.NumContinuousActions == 0);
var va1 = new VectorActuator(ar, new[] {4}, SpaceType.Continuous, "name");
Assert.IsTrue(va1.ActionSpec.NumContinuousActions == 4);
Assert.IsTrue(va1.ActionSpec.SumOfDiscreteBranchSizes == 0);
Assert.AreEqual(va1.Name, "name-Continuous");
}
[Test]
public void TestOnActionReceived()
{
var ar = new TestActionReceiver();
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
var discreteActions = new[] { 0, 1, 1 };
var ab = new ActionBuffers(ActionSegment<float>.Empty,
new ActionSegment<int>(discreteActions, 0, 3));
va.OnActionReceived(ab);
Assert.AreEqual(ar.LastActionBuffers, ab);
va.ResetData();
Assert.AreEqual(va.ActionBuffers.ContinuousActions, ActionSegment<float>.Empty);
Assert.AreEqual(va.ActionBuffers.DiscreteActions, ActionSegment<int>.Empty);
}
[Test]
public void TestResetData()
{
var ar = new TestActionReceiver();
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
var discreteActions = new[] { 0, 1, 1 };
var ab = new ActionBuffers(ActionSegment<float>.Empty,
new ActionSegment<int>(discreteActions, 0, 3));
va.OnActionReceived(ab);
}
[Test]
public void TestWriteDiscreteActionMask()
{
var ar = new TestActionReceiver();
var va = new VectorActuator(ar, new[] {1, 2, 3}, SpaceType.Discrete, "name");
var bdam = new ActuatorDiscreteActionMask(new[] {va}, 6, 3);
var groundTruthMask = new[] { false, true, false, false, true, true };
ar.Branch = 1;
ar.Mask = new[] { 0 };
va.WriteDiscreteActionMask(bdam);
ar.Branch = 2;
ar.Mask = new[] { 1, 2 };
va.WriteDiscreteActionMask(bdam);
Assert.IsTrue(groundTruthMask.SequenceEqual(bdam.GetMask()));
}
}
}

3
com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs.meta


fileFormatVersion: 2
guid: c2b191d2929f49adab0769705d49d86a
timeCreated: 1596580289
正在加载...
取消
保存