
Merge branch 'develop-hybrid-actions-singleton' into develop-hybrid-actions-csharp

/MLA-1734-demo-provider
Ruo-Ping Dong · 4 years ago
Commit a7d04be6
105 files changed, with 6,632 insertions and 293 deletions
Changed files (lines changed in parentheses):
  1. README.md (10)
  2. com.unity.ml-agents.extensions/Documentation~/com.unity.ml-agents.extensions.md (5)
  3. com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs (7)
  4. com.unity.ml-agents/Documentation~/com.unity.ml-agents.md (2)
  5. com.unity.ml-agents/Runtime/Academy.cs (4)
  6. com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs (2)
  7. com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs (2)
  8. com.unity.ml-agents/Runtime/Agent.cs (26)
  9. com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs (2)
  10. com.unity.ml-agents/Runtime/DiscreteActionMasker.cs (2)
  11. com.unity.ml-agents/Runtime/SensorHelper.cs (2)
  12. docs/Installation-Anaconda-Windows.md (4)
  13. docs/Installation.md (6)
  14. docs/Learning-Environment-Examples.md (21)
  15. docs/Training-on-Amazon-Web-Service.md (2)
  16. docs/Unity-Inference-Engine.md (4)
  17. ml-agents-envs/mlagents_envs/base_env.py (78)
  18. ml-agents/mlagents/trainers/agent_processor.py (26)
  19. ml-agents/mlagents/trainers/env_manager.py (12)
  20. ml-agents/mlagents/trainers/policy/policy.py (12)
  21. ml-agents/mlagents/trainers/policy/tf_policy.py (22)
  22. ml-agents/mlagents/trainers/policy/torch_policy.py (10)
  23. ml-agents/mlagents/trainers/ppo/optimizer_tf.py (6)
  24. ml-agents/mlagents/trainers/simple_env_manager.py (3)
  25. ml-agents/mlagents/trainers/subprocess_env_manager.py (7)
  26. ml-agents/mlagents/trainers/tests/mock_brain.py (20)
  27. ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py (4)
  28. ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py (2)
  29. ml-agents/mlagents/trainers/tests/test_agent_processor.py (27)
  30. ml-agents/mlagents/trainers/tests/test_trajectory.py (4)
  31. ml-agents/mlagents/trainers/tests/torch/test_distributions.py (2)
  32. ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (82)
  33. ml-agents/mlagents/trainers/tests/torch/test_policy.py (10)
  34. ml-agents/mlagents/trainers/tests/torch/test_ppo.py (11)
  35. ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (4)
  36. ml-agents/mlagents/trainers/tests/torch/test_utils.py (4)
  37. ml-agents/mlagents/trainers/torch/action_flattener.py (14)
  38. ml-agents/mlagents/trainers/torch/action_log_probs.py (44)
  39. ml-agents/mlagents/trainers/torch/action_model.py (84)
  40. ml-agents/mlagents/trainers/torch/agent_action.py (21)
  41. ml-agents/mlagents/trainers/torch/components/bc/module.py (12)
  42. ml-agents/mlagents/trainers/torch/distributions.py (2)
  43. ml-agents/mlagents/trainers/trajectory.py (25)
  44. utils/make_readme_table.py (1)
  45. Project/Assets/ML-Agents/Examples/Match3.meta (8)
  46. com.unity.ml-agents.extensions/Documentation~/Match3.md (67)
  47. com.unity.ml-agents.extensions/Runtime/Match3.meta (3)
  48. com.unity.ml-agents.extensions/Tests/Editor/Match3.meta (3)
  49. config/ppo/Match3.yaml (75)
  50. docs/images/match3.png (77)
  51. Project/Assets/ML-Agents/Examples/Match3/Prefabs.meta (8)
  52. Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab (174)
  53. Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab.meta (7)
  54. Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab (170)
  55. Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab.meta (7)
  56. Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab (170)
  57. Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab.meta (7)
  58. Project/Assets/ML-Agents/Examples/Match3/Scenes.meta (8)
  59. Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity (1001)
  60. Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity.meta (7)
  61. Project/Assets/ML-Agents/Examples/Match3/Scripts.meta (8)
  62. Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs (373)
  63. Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs.meta (3)
  64. Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs (272)
  65. Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs.meta (11)
  66. Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs (102)
  67. Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs.meta (3)
  68. Project/Assets/ML-Agents/Examples/Match3/TFModels.meta (8)
  69. Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx (1001)
  70. Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx.meta (14)
  71. Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn (1001)
  72. Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn.meta (11)
  73. com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs (233)
  74. com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs.meta (3)
  75. com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs (120)
  76. com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs.meta (3)
  77. com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs (49)
  78. com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs.meta (3)
  79. com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs (297)
  80. com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs.meta (3)
  81. com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs (43)
  82. com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs.meta (3)
  83. com.unity.ml-agents.extensions/Runtime/Match3/Move.cs (260)
  84. com.unity.ml-agents.extensions/Runtime/Match3/Move.cs.meta (3)
  85. com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs (152)
  86. com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs.meta (3)
  87. com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs (115)
  88. com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs.meta (3)
  89. com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs (314)
  90. com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs.meta (3)
  91. com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs (60)
  92. com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs.meta (3)
  93. com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs0.png (3)

README.md (10)


# Unity ML-Agents Toolkit
[![docs badge](https://img.shields.io/badge/docs-reference-blue.svg)](https://github.com/Unity-Technologies/ml-agents/tree/release_8_docs/docs/)
[![docs badge](https://img.shields.io/badge/docs-reference-blue.svg)](https://github.com/Unity-Technologies/ml-agents/tree/release_9_docs/docs/)
[![license badge](https://img.shields.io/badge/license-Apache--2.0-green.svg)](LICENSE)

## Releases & Documentation
**Our latest, stable release is `Release 8`. Click
[here](https://github.com/Unity-Technologies/ml-agents/tree/release_8_docs/docs/Readme.md)
**Our latest, stable release is `Release 9`. Click
[here](https://github.com/Unity-Technologies/ml-agents/tree/release_9_docs/docs/Readme.md)
to get started with the latest release of ML-Agents.**
The table below lists all our releases, including our `master` branch which is under active development and may be unstable.

| **Version** | **Release Date** | **Source** | **Documentation** | **Download** |
|:-------:|:------:|:-------------:|:-------:|:------------:|
| **master (unstable)** | -- | [source](https://github.com/Unity-Technologies/ml-agents/tree/master) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/master/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/master.zip) |
| **Release 8** | **October 14, 2020** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/release_8)** | **[docs](https://github.com/Unity-Technologies/ml-agents/tree/release_8_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/release_8.zip)** |
| **Release 9** | **November 4, 2020** | **[source](https://github.com/Unity-Technologies/ml-agents/tree/release_9)** | **[docs](https://github.com/Unity-Technologies/ml-agents/tree/release_9_docs/docs/Readme.md)** | **[download](https://github.com/Unity-Technologies/ml-agents/archive/release_9.zip)** |
| **Release 8** | October 14, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_8) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_8_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_8.zip) |
| **Release 2** | May 20, 2020 | [source](https://github.com/Unity-Technologies/ml-agents/tree/release_2) | [docs](https://github.com/Unity-Technologies/ml-agents/tree/release_2_docs/docs/Readme.md) | [download](https://github.com/Unity-Technologies/ml-agents/archive/release_2.zip) |
## Citation

com.unity.ml-agents.extensions/Documentation~/com.unity.ml-agents.extensions.md (5)


| _Runtime_ | Contains core C# APIs for integrating ML-Agents into your Unity scene. |
| _Tests_ | Contains the unit tests for the package. |
The Runtime directory currently contains three features:
* [Match-3 sensor and actuator](Match3.md)
* [Grid-based sensor](Grid-Sensor.md)
* Physics-based sensors
## Installation
The ML-Agents Extensions package is not currently available in the Package Manager. There are two
recommended ways to install the package:

com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs (7)


bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
Assert.IsTrue(isOK, errorMessage);
}
public static void CompareObservation(ISensor sensor, float[,,] expected)
{
string errorMessage;
bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
Assert.IsTrue(isOK, errorMessage);
}
}
public class RigidBodySensorTests

com.unity.ml-agents/Documentation~/com.unity.ml-agents.md (2)


[unity ML-Agents Toolkit]: https://github.com/Unity-Technologies/ml-agents
[unity inference engine]: https://docs.unity3d.com/Packages/com.unity.barracuda@latest/index.html
[package manager documentation]: https://docs.unity3d.com/Manual/upm-ui-install.html
[installation instructions]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Installation.md
[installation instructions]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Installation.md
[github repository]: https://github.com/Unity-Technologies/ml-agents
[python package]: https://github.com/Unity-Technologies/ml-agents
[execution order of event functions]: https://docs.unity3d.com/Manual/ExecutionOrder.html

com.unity.ml-agents/Runtime/Academy.cs (4)


* API. For more information on each of these entities, in addition to how to
* set-up a learning environment and train the behavior of characters in a
* Unity scene, please browse our documentation pages on GitHub:
* https://github.com/Unity-Technologies/ml-agents/tree/release_8_docs/docs/
* https://github.com/Unity-Technologies/ml-agents/tree/release_9_docs/docs/
*/
namespace Unity.MLAgents

/// fall back to inference or heuristic decisions. (You can also set agents to always use
/// inference or heuristics.)
/// </remarks>
[HelpURL("https://github.com/Unity-Technologies/ml-agents/tree/release_8_docs/" +
[HelpURL("https://github.com/Unity-Technologies/ml-agents/tree/release_9_docs/" +
"docs/Learning-Environment-Design.md")]
public class Academy : IDisposable
{

com.unity.ml-agents/Runtime/Actuators/IActionReceiver.cs (2)


///
/// See [Agents - Actions] for more information on masking actions.
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design-Agents.md#actions
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <seealso cref="IActionReceiver.OnActionReceived"/>
void WriteDiscreteActionMask(IDiscreteActionMask actionMask);

com.unity.ml-agents/Runtime/Actuators/IDiscreteActionMask.cs (2)


///
/// See [Agents - Actions] for more information on masking actions.
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design-Agents.md#actions
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <param name="branch">The branch for which the actions will be masked.</param>
/// <param name="actionIndices">The indices of the masked actions.</param>

com.unity.ml-agents/Runtime/Agent.cs (26)


/// [OnDisable()]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.OnDisable.html
/// [OnBeforeSerialize()]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.OnBeforeSerialize.html
/// [OnAfterSerialize()]: https://docs.unity3d.com/ScriptReference/MonoBehaviour.OnAfterSerialize.html
/// [Agents]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design-Agents.md
/// [Reinforcement Learning in Unity]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design.md
/// [Agents]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design-Agents.md
/// [Reinforcement Learning in Unity]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design.md
/// [Unity ML-Agents Toolkit manual]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Readme.md
/// [Unity ML-Agents Toolkit manual]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Readme.md
[HelpURL("https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/" +
[HelpURL("https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/" +
"docs/Learning-Environment-Design-Agents.md")]
[Serializable]
[RequireComponent(typeof(BehaviorParameters))]

/// for information about mixing reward signals from curiosity and Generative Adversarial
/// Imitation Learning (GAIL) with rewards supplied through this method.
///
/// [Agents - Rewards]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design-Agents.md#rewards
/// [Reward Signals]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/ML-Agents-Overview.md#a-quick-note-on-reward-signals
/// [Agents - Rewards]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design-Agents.md#rewards
/// [Reward Signals]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/ML-Agents-Overview.md#a-quick-note-on-reward-signals
/// </remarks>
/// <param name="reward">The new value of the reward.</param>
public void SetReward(float reward)

/// for information about mixing reward signals from curiosity and Generative Adversarial
/// Imitation Learning (GAIL) with rewards supplied through this method.
///
/// [Agents - Rewards]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design-Agents.md#rewards
/// [Reward Signals]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/ML-Agents-Overview.md#a-quick-note-on-reward-signals
/// [Agents - Rewards]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design-Agents.md#rewards
/// [Reward Signals]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/ML-Agents-Overview.md#a-quick-note-on-reward-signals
///</remarks>
/// <param name="increment">Incremental reward value.</param>
public void AddReward(float increment)

/// implementing a simple heuristic function can aid in debugging agent actions and interactions
/// with its environment.
///
/// [Demonstration Recorder]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design-Agents.md#recording-demonstrations
/// [Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design-Agents.md#actions
/// [Demonstration Recorder]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design-Agents.md#recording-demonstrations
/// [Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design-Agents.md#actions
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
/// </remarks>
/// <example>

/// For more information about observations, see [Observations and Sensors].
///
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
/// [Observations and Sensors]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design-Agents.md#observations-and-sensors
/// [Observations and Sensors]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design-Agents.md#observations-and-sensors
/// </remarks>
public virtual void CollectObservations(VectorSensor sensor)
{

///
/// See [Agents - Actions] for more information on masking actions.
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design-Agents.md#actions
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <seealso cref="IActionReceiver.OnActionReceived"/>
public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask)

///
/// For more information about implementing agent actions see [Agents - Actions].
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design-Agents.md#actions
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <param name="actions">
/// Struct containing the buffers of actions to be executed at this step.

com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs (2)


/// See [Imitation Learning - Recording Demonstrations] for more information.
///
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
/// [Imitation Learning - Recording Demonstrations]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs//Learning-Environment-Design-Agents.md#recording-demonstrations
/// [Imitation Learning - Recording Demonstrations]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs//Learning-Environment-Design-Agents.md#recording-demonstrations
/// </remarks>
[RequireComponent(typeof(Agent))]
[AddComponentMenu("ML Agents/Demonstration Recorder", (int)MenuGroup.Default)]

com.unity.ml-agents/Runtime/DiscreteActionMasker.cs (2)


///
/// See [Agents - Actions] for more information on masking actions.
///
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/docs/Learning-Environment-Design-Agents.md#actions
/// [Agents - Actions]: https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/docs/Learning-Environment-Design-Agents.md#actions
/// </remarks>
/// <param name="branch">The branch for which the actions will be masked.</param>
/// <param name="actionIndices">The indices of the masked actions.</param>

com.unity.ml-agents/Runtime/SensorHelper.cs (2)


if (expected[h, w, c] != output[tensorShape.Index(0, h, w, c)])
{
errorMessage = $"Expected and actual differed in position [{h}, {w}, {c}]. " +
"Expected: {expected[h, w, c]} Actual: {output[tensorShape.Index(0, h, w, c)]} ";
$"Expected: {expected[h, w, c]} Actual: {output[tensorShape.Index(0, h, w, c)]} ";
return false;
}
}

docs/Installation-Anaconda-Windows.md (4)


the ml-agents Conda environment by typing `activate ml-agents`)_:
```sh
git clone --branch release_8 https://github.com/Unity-Technologies/ml-agents.git
git clone --branch release_9 https://github.com/Unity-Technologies/ml-agents.git
The `--branch release_8` option will switch to the tag of the latest stable
The `--branch release_9` option will switch to the tag of the latest stable
release. Omitting that will get the `master` branch which is potentially
unstable.

docs/Installation.md (6)


of our tutorials / guides assume you have access to our example environments).
```sh
git clone --branch release_8 https://github.com/Unity-Technologies/ml-agents.git
git clone --branch release_9 https://github.com/Unity-Technologies/ml-agents.git
The `--branch release_8` option will switch to the tag of the latest stable
The `--branch release_9` option will switch to the tag of the latest stable
release. Omitting that will get the `master` branch which is potentially
unstable.

ML-Agents Toolkit for your purposes. If you plan to contribute those changes
back, make sure to clone the `master` branch (by omitting `--branch release_8`
back, make sure to clone the `master` branch (by omitting `--branch release_9`
from the command above). See our
[Contributions Guidelines](../com.unity.ml-agents/CONTRIBUTING.md) for more
information on contributing to the ML-Agents Toolkit.

docs/Learning-Environment-Examples.md (21)


does not train with the provided default training parameters.**
- Float Properties: None
- Benchmark Mean Reward: 1.75
## Match 3
![Match 3](images/match3.png)
- Set-up: Simple match-3 game. Matched pieces are removed, and remaining pieces
drop down. New pieces are spawned randomly at the top, with a chance of being
"special".
- Goal: Maximize score from matching pieces.
- Agents: The environment contains several independent Agents.
- Agent Reward Function (independent):
- .01 for each normal piece cleared. Special pieces are worth 2x or 3x.
- Behavior Parameters:
- None
- Observations and actions are defined with a sensor and actuator respectively.
- Float Properties: None
- Benchmark Mean Reward:
- 37.2 for visual observations
- 37.6 for vector observations
- 34.2 for simple heuristic (pick a random valid move)
- 37.0 for greedy heuristic (pick the highest-scoring valid move)

docs/Training-on-Amazon-Web-Service.md (2)


2. Clone the ML-Agents repo and install the required Python packages
```sh
git clone --branch release_8 https://github.com/Unity-Technologies/ml-agents.git
git clone --branch release_9 https://github.com/Unity-Technologies/ml-agents.git
cd ml-agents/ml-agents/
pip3 install -e .
```

docs/Unity-Inference-Engine.md (4)


loading expects certain conventions for constants and tensor names. While it is
possible to construct a model that follows these conventions, we don't provide
any additional help for this. More details can be found in
[TensorNames.cs](https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/com.unity.ml-agents/Runtime/Inference/TensorNames.cs)
[TensorNames.cs](https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/com.unity.ml-agents/Runtime/Inference/TensorNames.cs)
[BarracudaModelParamLoader.cs](https://github.com/Unity-Technologies/ml-agents/blob/release_8_docs/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs).
[BarracudaModelParamLoader.cs](https://github.com/Unity-Technologies/ml-agents/blob/release_9_docs/com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs).
If you wish to run inference on an externally trained model, you should use
Barracuda directly, instead of trying to run it through ML-Agents.

ml-agents-envs/mlagents_envs/base_env.py (78)


)
class ActionTuple:
class _ActionTupleBase(ABC):
An object whose fields correspond to actions of different types.
Continuous and discrete actions are numpy arrays of type float32 and
int32, respectively and are type checked on construction.
Dimensions are of (n_agents, continuous_size) and (n_agents, discrete_size),
respectively.
An object whose fields correspond to action data of continuous and discrete
spaces. Dimensions are of (n_agents, continuous_size) and (n_agents, discrete_size),
respectively. Note, this also holds when continuous or discrete size is
zero.
"""
def __init__(

):
if continuous is not None and continuous.dtype != np.float32:
continuous = continuous.astype(np.float32, copy=False)
self._continuous = continuous
if discrete is not None and discrete.dtype != np.int32:
discrete = discrete.astype(np.int32, copy=False)
self._discrete = discrete
self._continuous: Optional[np.ndarray] = None
self._discrete: Optional[np.ndarray] = None
if continuous is not None:
self.add_continuous(continuous)
if discrete is not None:
self.add_discrete(discrete)
@property
def continuous(self) -> np.ndarray:

def discrete(self) -> np.ndarray:
return self._discrete
def add_continuous(self, continuous: np.ndarray) -> None:
if continuous.dtype != np.float32:
continuous = continuous.astype(np.float32, copy=False)
if self._discrete is None:
_discrete_dtype = self.get_discrete_dtype()
self._discrete = np.zeros((continuous.shape[0], 0), dtype=_discrete_dtype)
self._continuous = continuous
def add_discrete(self, discrete: np.ndarray) -> None:
_discrete_dtype = self.get_discrete_dtype()
if discrete.dtype != _discrete_dtype:
discrete = discrete.astype(np.int32, copy=False)
if self._continuous is None:
self._continuous = np.zeros((discrete.shape[0], 0), dtype=np.float32)
self._discrete = discrete
@abstractmethod
def get_discrete_dtype(self) -> np.dtype:
pass
class ActionTuple(_ActionTupleBase):
"""
An object whose fields correspond to actions of different types.
Continuous and discrete actions are numpy arrays of type float32 and
int32, respectively and are type checked on construction.
Dimensions are of (n_agents, continuous_size) and (n_agents, discrete_size),
respectively. Note, this also holds when continuous or discrete size is
zero.
"""
def get_discrete_dtype(self) -> np.dtype:
"""
The dtype of a discrete action.
"""
return np.int32
class ActionSpec(NamedTuple):
"""

for a number of agents.
:param n_agents: The number of agents that will have actions generated
"""
continuous = np.zeros((n_agents, self.continuous_size), dtype=np.float32)
discrete = np.zeros((n_agents, self.discrete_size), dtype=np.int32)
return ActionTuple(continuous, discrete)
_continuous = np.zeros((n_agents, self.continuous_size), dtype=np.float32)
_discrete = np.zeros((n_agents, self.discrete_size), dtype=np.int32)
return ActionTuple(continuous=_continuous, discrete=_discrete)
def random_action(self, n_agents: int) -> ActionTuple:
"""

"""
continuous = np.random.uniform(
_continuous = np.random.uniform(
discrete = np.zeros((n_agents, self.discrete_size), dtype=np.int32)
_discrete = np.zeros((n_agents, self.discrete_size), dtype=np.int32)
discrete = np.column_stack(
_discrete = np.column_stack(
[
np.random.randint(
0,

for i in range(self.discrete_size)
]
)
return ActionTuple(continuous, discrete)
return ActionTuple(continuous=_continuous, discrete=_discrete)
def _validate_action(
self, actions: ActionTuple, n_agents: int, name: str

for the correct number of agents and ensures the type.
"""
_expected_shape = (n_agents, self.continuous_size)
if self.continuous_size > 0 and actions.continuous.shape != _expected_shape:
if actions.continuous.shape != _expected_shape:
raise UnityActionException(
f"The behavior {name} needs a continuous input of dimension "
f"{_expected_shape} for (<number of agents>, <action size>) but "

if self.discrete_size > 0 and actions.discrete.shape != _expected_shape:
if actions.discrete.shape != _expected_shape:
raise UnityActionException(
f"The behavior {name} needs a discrete input of dimension "
f"{_expected_shape} for (<number of agents>, <action size>) but "

ml-agents/mlagents/trainers/agent_processor.py (26)


from typing import List, Dict, TypeVar, Generic, Tuple, Any, Union
from collections import defaultdict, Counter
import queue
import numpy as np
ActionTuple,
DecisionSteps,
DecisionStep,
TerminalSteps,

from mlagents.trainers.trajectory import Trajectory, AgentExperience
from mlagents.trainers.policy import Policy
from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.behavior_id_utils import get_global_agent_id

done = terminated # Since this is an ongoing step
interrupted = step.interrupted if terminated else False
# Add the outputs of the last eval
action_dict = stored_take_action_outputs["action"]
action: Dict[str, np.ndarray] = {}
for act_type, act_array in action_dict.items():
action[act_type] = act_array[idx]
stored_actions = stored_take_action_outputs["action"]
action_tuple = ActionTuple(
continuous=stored_actions.continuous[idx],
discrete=stored_actions.discrete[idx],
)
action_probs_dict = stored_take_action_outputs["log_probs"]
action_probs: Dict[str, np.ndarray] = {}
for prob_type, prob_array in action_probs_dict.items():
action_probs[prob_type] = prob_array[idx]
stored_action_probs = stored_take_action_outputs["log_probs"]
log_probs_tuple = LogProbsTuple(
continuous=stored_action_probs.continuous[idx],
discrete=stored_action_probs.discrete[idx],
)
action_mask = stored_decision_step.action_mask
prev_action = self.policy.retrieve_previous_action([global_id])[0, :]
experience = AgentExperience(

action=action,
action_probs=action_probs,
action=action_tuple,
action_probs=log_probs_tuple,
action_pre=action_pre,
action_mask=action_mask,
prev_action=prev_action,

ml-agents/mlagents/trainers/env_manager.py (12)


from abc import ABC, abstractmethod
import numpy as np
from typing import List, Dict, NamedTuple, Iterable, Tuple
from mlagents_envs.base_env import (

BehaviorName,
ActionTuple,
)
from mlagents_envs.side_channel.stats_side_channel import EnvironmentStats

step_info.environment_stats, step_info.worker_id
)
return len(step_infos)
@staticmethod
def action_tuple_from_numpy_dict(action_dict: Dict[str, np.ndarray]) -> ActionTuple:
continuous: np.ndarray = None
discrete: np.ndarray = None
if "continuous_action" in action_dict:
continuous = action_dict["continuous_action"]
if "discrete_action" in action_dict:
discrete = action_dict["discrete_action"]
return ActionTuple(continuous, discrete)
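A brief usage sketch for the new static helper (the input dict uses the legacy `continuous_action`/`discrete_action` keys handled above; the shapes are illustrative):

```python
import numpy as np
from mlagents.trainers.env_manager import EnvManager

legacy = {"continuous_action": np.zeros((2, 3), dtype=np.float32)}
action = EnvManager.action_tuple_from_numpy_dict(legacy)
# action.continuous has shape (2, 3); action.discrete falls back to the
# zero-width (2, 0) default provided by ActionTuple.
```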

ml-agents/mlagents/trainers/policy/policy.py (12)


from typing import Dict, List, Optional
import numpy as np
from mlagents_envs.base_env import DecisionSteps
from mlagents_envs.base_env import ActionTuple, BehaviorSpec, DecisionSteps
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings, NetworkSettings

condition_sigma_on_obs: bool = True,
):
self.behavior_spec = behavior_spec
self.action_spec = behavior_spec.action_spec
self.trainer_settings = trainer_settings
self.network_settings: NetworkSettings = trainer_settings.network_settings
self.seed = seed

)
def save_previous_action(
self, agent_ids: List[str], action_dict: Dict[str, np.ndarray]
self, agent_ids: List[str], action_tuple: ActionTuple
if action_dict is None or "discrete_action" not in action_dict:
return
self.previous_action_dict[agent_id] = action_dict["discrete_action"][
index, :
]
self.previous_action_dict[agent_id] = action_tuple.discrete[index, :]
def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
action_matrix = self.make_empty_previous_action(len(agent_ids))

ml-agents/mlagents/trainers/policy/tf_policy.py (22)


from mlagents.tf_utils import tf
from mlagents import tf_utils
from mlagents_envs.exception import UnityException
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents_envs.base_env import DecisionSteps
from mlagents_envs.base_env import DecisionSteps, ActionTuple, BehaviorSpec
from mlagents.trainers.tf.models import ModelUtils
from mlagents.trainers.settings import TrainerSettings, EncoderType
from mlagents.trainers import __version__

reparameterize,
condition_sigma_on_obs,
)
if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
if (
self.behavior_spec.action_spec.continuous_size > 0
and self.behavior_spec.action_spec.discrete_size > 0
):
raise UnityPolicyException(
"TensorFlow does not support mixed action spaces. Please run with the Torch framework."
)

self.save_memories(global_agent_ids, run_out.get("memory_out"))
# For Compatibility with buffer changes for hybrid action support
if "log_probs" in run_out:
run_out["log_probs"] = {"action_probs": run_out["log_probs"]}
log_probs_tuple = LogProbsTuple()
if self.behavior_spec.action_spec.is_continuous():
log_probs_tuple.add_continuous(run_out["log_probs"])
else:
log_probs_tuple.add_discrete(run_out["log_probs"])
run_out["log_probs"] = log_probs_tuple
action_tuple = ActionTuple()
run_out["action"] = {"continuous_action": run_out["action"]}
action_tuple.add_continuous(run_out["action"])
run_out["action"] = {"discrete_action": run_out["action"]}
action_tuple.add_discrete(run_out["action"])
run_out["action"] = action_tuple
return ActionInfo(
action=run_out.get("action"),
value=run_out.get("value"),

ml-agents/mlagents/trainers/policy/torch_policy.py (10)


) -> Tuple[SplitObservations, np.ndarray]:
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
mask = None
if self.action_spec.discrete_size > 0:
if self.behavior_spec.action_spec.discrete_size > 0:
mask = torch.ones([len(decision_requests), np.sum(self.act_size)])
if decision_requests.action_mask is not None:
mask = torch.as_tensor(

action, log_probs, entropy, memories = self.sample_actions(
vec_obs, vis_obs, masks=masks, memories=memories
)
action_dict = action.to_numpy_dict()
run_out["action"] = action_dict
action_tuple = action.to_action_tuple()
run_out["action"] = action_tuple
action_dict["continuous_action"] if self.use_continuous_act else None
action_tuple.continuous if self.use_continuous_act else None
run_out["log_probs"] = log_probs.to_numpy_dict()
run_out["log_probs"] = log_probs.to_log_probs_tuple()
run_out["entropy"] = ModelUtils.to_numpy(entropy)
run_out["learning_rate"] = 0.0
if self.use_recurrent:

ml-agents/mlagents/trainers/ppo/optimizer_tf.py (6)


self.policy.sequence_length_ph: self.policy.sequence_length,
self.policy.mask_input: mini_batch["masks"] * burn_in_mask,
self.advantage: mini_batch["advantages"],
self.all_old_log_probs: mini_batch["action_probs"],
if self.policy.use_continuous_act: # For hybrid action buffer support
feed_dict[self.all_old_log_probs] = mini_batch["continuous_log_probs"]
else:
feed_dict[self.all_old_log_probs] = mini_batch["discrete_log_probs"]
if self.policy.output_pre is not None and "actions_pre" in mini_batch:
feed_dict[self.policy.output_pre] = mini_batch["actions_pre"]

ml-agents/mlagents/trainers/simple_env_manager.py (3)


self.previous_all_action_info = all_action_info
for brain_name, action_info in all_action_info.items():
_action = EnvManager.action_tuple_from_numpy_dict(action_info.action)
self.env.set_actions(brain_name, _action)
self.env.set_actions(brain_name, action_info.action)
self.env.step()
all_step_result = self._generate_all_results()

ml-agents/mlagents/trainers/subprocess_env_manager.py (7)


if req.cmd == EnvironmentCommand.STEP:
all_action_info = req.payload
for brain_name, action_info in all_action_info.items():
if len(action_info.action) != 0:
_action = EnvManager.action_tuple_from_numpy_dict(
action_info.action
)
env.set_actions(brain_name, _action)
if len(action_info.agent_ids) > 0:
env.set_actions(brain_name, action_info.action)
env.step()
all_step_result = _generate_all_results()
# The timers in this process are independent from all the processes and the "main" process

ml-agents/mlagents/trainers/tests/mock_brain.py (20)


import numpy as np
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents.trainers.trajectory import Trajectory, AgentExperience
from mlagents_envs.base_env import (
DecisionSteps,

ActionTuple,
)

steps_list = []
action_size = action_spec.discrete_size + action_spec.continuous_size
action_probs = {
"action_probs": np.ones(
int(np.sum(action_spec.discrete_branches) + action_spec.continuous_size),
dtype=np.float32,
)
}
for _i in range(length - 1):
obs = []
for _shape in observation_shapes:

if action_spec.is_continuous():
action = {"continuous_action": np.zeros(action_size, dtype=np.float32)}
else:
action = {"discrete_action": np.zeros(action_size, dtype=np.float32)}
action = ActionTuple(
continuous=np.zeros(action_spec.continuous_size, dtype=np.float32),
discrete=np.zeros(action_spec.discrete_size, dtype=np.int32),
)
action_probs = LogProbsTuple(
continuous=np.ones(action_spec.continuous_size, dtype=np.float32),
discrete=np.ones(action_spec.discrete_size, dtype=np.float32),
)
action_pre = np.zeros(action_size, dtype=np.float32)
action_mask = (
[

ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py (4)


@pytest.mark.parametrize("action_sizes", [(0, 2), (2, 0)])
def test_2d_sac(action_sizes):
env = SimpleEnvironment(
[BRAIN_NAME], action_sizes=action_sizes, action_size=2, step_size=0.8
)
env = SimpleEnvironment([BRAIN_NAME], action_sizes=action_sizes, step_size=0.8)
new_hyperparams = attr.evolve(SAC_TF_CONFIG.hyperparameters, buffer_init_steps=2000)
config = attr.evolve(
SAC_TF_CONFIG,

ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py (2)


behavior_spec = basic_behavior_spec()
policy = FakePolicy(test_seed, behavior_spec, TrainerSettings(), "output")
policy_eval_out = {
"action": {"continuous_action": np.array([1.0], dtype=np.float32)},
"action": np.array([[1.0]], dtype=np.float32),
"memory_out": np.array([[2.5]], dtype=np.float32),
"value": np.array([1.1], dtype=np.float32),
}

ml-agents/mlagents/trainers/tests/test_agent_processor.py (27)


AgentManagerQueue,
)
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents_envs.base_env import ActionSpec
from mlagents_envs.base_env import ActionSpec, ActionTuple
def create_mock_policy():

)
fake_action_outputs = {
"action": {"continuous_action": [0.1, 0.1]},
"action": ActionTuple(continuous=np.array([[0.1], [0.1]])),
"log_probs": {"continuous_log_probs": [0.1, 0.1]},
"log_probs": LogProbsTuple(continuous=np.array([[0.1], [0.1]])),
}
mock_decision_steps, mock_terminal_steps = mb.create_mock_steps(
num_agents=2,

fake_action_info = ActionInfo(
action={"continuous_action": [0.1, 0.1]},
action=ActionTuple(continuous=np.array([[0.1], [0.1]])),
value=[0.1, 0.1],
outputs=fake_action_outputs,
agent_ids=mock_decision_steps.agent_id,

max_trajectory_length=5,
stats_reporter=StatsReporter("testcat"),
)
"action": {"continuous_action": [0.1]},
"action": ActionTuple(continuous=np.array([[0.1]])),
"log_probs": {"continuous_log_probs": [0.1]},
"log_probs": LogProbsTuple(continuous=np.array([[0.1]])),
mock_decision_step, mock_terminal_step = mb.create_mock_steps(
num_agents=1,
observation_shapes=[(8,)],

done=True,
)
fake_action_info = ActionInfo(
action={"continuous_action": [0.1]},
action=ActionTuple(continuous=np.array([[0.1]])),
value=[0.1],
outputs=fake_action_outputs,
agent_ids=mock_decision_step.agent_id,

mock_decision_step, mock_terminal_step, _ep, fake_action_info
)
add_calls.append(
mock.call([get_global_agent_id(_ep, 0)], {"continuous_action": [0.1]})
mock.call([get_global_agent_id(_ep, 0)], fake_action_outputs["action"])
)
processor.add_experiences(
mock_done_decision_step, mock_done_terminal_step, _ep, fake_action_info

max_trajectory_length=5,
stats_reporter=StatsReporter("testcat"),
)
"action": {"continuous_action": [0.1]},
"action": ActionTuple(continuous=np.array([[0.1]])),
"log_probs": {"continuous_log_probs": [0.1]},
"log_probs": LogProbsTuple(continuous=np.array([[0.1]])),
mock_decision_step, mock_terminal_step = mb.create_mock_steps(
num_agents=1,
observation_shapes=[(8,)],

action={"continuous_action": [0.1]},
action=ActionTuple(continuous=np.array([[0.1]])),
value=[0.1],
outputs=fake_action_outputs,
agent_ids=mock_decision_step.agent_id,

ml-agents/mlagents/trainers/tests/test_trajectory.py (4)


"done",
"actions_pre",
"continuous_action",
"action_probs",
"discrete_action",
"continuous_log_probs",
"discrete_log_probs",
"action_mask",
"prev_action",
"environment_rewards",

ml-agents/mlagents/trainers/tests/torch/test_distributions.py (2)


optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)
for _ in range(50):
dist_inst = gauss_dist(sample_embedding)[0]
dist_inst = gauss_dist(sample_embedding)
if tanh_squash:
assert isinstance(dist_inst, TanhGaussianDistInstance)
else:

ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (82)


from mlagents.trainers.tests.simple_test_envs import (
SimpleEnvironment,
MemoryEnvironment,
RecordEnvironment,
from mlagents.trainers.demo_loader import write_demo
from mlagents.trainers.settings import (
NetworkSettings,
SelfPlaySettings,
BehavioralCloningSettings,
GAILSettings,
RewardSignalType,
EncoderType,
FrameworkType,
)
from mlagents_envs.communicator_objects.demonstration_meta_pb2 import (
DemonstrationMetaProto,
)
from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents_envs.communicator_objects.space_type_pb2 import discrete, continuous
from mlagents.trainers.settings import NetworkSettings, FrameworkType
from mlagents.trainers.tests.check_env_trains import (
check_environment_trains,
default_reward_processor,
)
from mlagents.trainers.tests.check_env_trains import check_environment_trains
BRAIN_NAME = "1D"

def test_hybrid_ppo():
env = SimpleEnvironment(
[BRAIN_NAME], continuous_action_size=1, discrete_action_size=1
)
config = attr.evolve(PPO_TORCH_CONFIG)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=1.0)
def test_conthybrid_ppo():
env = SimpleEnvironment(
[BRAIN_NAME], continuous_action_size=1, discrete_action_size=0
)
config = attr.evolve(PPO_TORCH_CONFIG)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=1.0)
def test_dischybrid_ppo():
env = SimpleEnvironment(
[BRAIN_NAME], continuous_action_size=0, discrete_action_size=1
)
env = SimpleEnvironment([BRAIN_NAME], action_sizes=(1, 1))
config = attr.evolve(PPO_TORCH_CONFIG)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=1.0)

env = SimpleEnvironment(
[BRAIN_NAME],
num_visual=num_visual,
num_vector=0,
continuous_action_size=1,
discrete_action_size=1,
[BRAIN_NAME], num_visual=num_visual, num_vector=0, action_sizes=(1, 1)
)
new_hyperparams = attr.evolve(
PPO_TORCH_CONFIG.hyperparameters, learning_rate=3.0e-4

def test_recurrent_ppo():
env = MemoryEnvironment(
[BRAIN_NAME], continuous_action_size=1, discrete_action_size=1
)
env = MemoryEnvironment([BRAIN_NAME], action_sizes=(1, 1))
new_network_settings = attr.evolve(
PPO_TORCH_CONFIG.network_settings,
memory=NetworkSettings.MemorySettings(memory_size=16),

PPO_TORCH_CONFIG,
hyperparameters=new_hyperparams,
network_settings=new_network_settings,
max_steps=100000,
max_steps=10000,
# def test_3cdhybrid_ppo():
# env = SimpleEnvironment(
# [BRAIN_NAME], continuous_action_size=2, discrete_action_size=1, step_size=0.8
# )
# new_hyperparams = attr.evolve(
# PPO_TORCH_CONFIG.hyperparameters, batch_size=128, buffer_size=1280, beta=0.01
# )
# config = attr.evolve(
# PPO_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=10000
# )
# check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=1.0)
#
#
# def test_3ddhybrid_ppo():
# env = SimpleEnvironment(
# [BRAIN_NAME], continuous_action_size=1, discrete_action_size=2, step_size=0.8
# )
# new_hyperparams = attr.evolve(
# PPO_TORCH_CONFIG.hyperparameters, batch_size=128, buffer_size=1280, beta=0.01
# )
# config = attr.evolve(
# PPO_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=10000
# )
# check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=1.0)

ml-agents/mlagents/trainers/tests/torch/test_policy.py (10)


run_out = policy.evaluate(decision_step, list(decision_step.agent_id))
if discrete:
run_out["action"]["discrete_action"].shape == (
NUM_AGENTS,
len(DISCRETE_ACTION_SPACE),
)
run_out["action"].discrete.shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE))
assert run_out["action"]["continuous_action"].shape == (
NUM_AGENTS,
VECTOR_ACTION_SPACE,
)
assert run_out["action"].continuous.shape == (NUM_AGENTS, VECTOR_ACTION_SPACE)
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])

ml-agents/mlagents/trainers/tests/torch/test_ppo.py (11)


update_buffer["extrinsic_returns"] = update_buffer["environment_rewards"]
update_buffer["extrinsic_value_estimates"] = update_buffer["environment_rewards"]
# NOTE: In TensorFlow, the log_probs are saved as one for every discrete action, whereas
# in PyTorch it is saved as the total probability per branch. So we need to modify the
# log prob in the fake buffer here.
if discrete:
update_buffer["discrete_log_probs"] = np.ones_like(
update_buffer["discrete_action"]
)
else:
update_buffer["continuous_log_probs"] = np.ones_like(
update_buffer["continuous_action"]
)
return_stats = optimizer.update(
update_buffer,
num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,

ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (4)


def test_visual_advanced_ppo(vis_encode_type, num_visual):
env = SimpleEnvironment(
[BRAIN_NAME],
action_sizes=True,
action_sizes=(0, 1),
num_visual=num_visual,
num_vector=0,
step_size=0.5,

def test_visual_advanced_sac(vis_encode_type, num_visual):
env = SimpleEnvironment(
[BRAIN_NAME],
action_sizes=True,
action_sizes=(0, 1),
num_visual=num_visual,
num_vector=0,
step_size=0.5,

ml-agents/mlagents/trainers/tests/torch/test_utils.py (4)


from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.torch.distributions import (
CategoricalDistInstance,
GaussianDistInstance,
)
def test_min_visual_size():

ml-agents/mlagents/trainers/torch/action_flattener.py (14)


class ActionFlattener:
def __init__(self, action_spec: ActionSpec):
"""
A torch module that creates the flattened form of an AgentAction object.
The flattened form is the continuous action concatenated with the
concatenated one hot encodings of the discrete actions.
:param action_spec: An ActionSpec that describes the action space dimensions
"""
"""
The flattened size is the continuous size plus the sum of the branch sizes
since discrete actions are encoded as one hots.
"""
"""
Returns a tensor corresponding the flattened action
:param action: An AgentAction object
"""
action_list: List[torch.Tensor] = []
if self._specs.continuous_size > 0:
action_list.append(action.continuous_tensor)
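To make the flattened layout concrete, here is a plain-numpy sketch of what the docstring describes, outside the torch module (values are illustrative):

```python
import numpy as np

# One agent: continuous size 2, plus one discrete branch with 3 choices.
continuous = np.array([0.5, -0.25], dtype=np.float32)
discrete_choice = 2  # index selected from the 3-way branch
one_hot = np.eye(3, dtype=np.float32)[discrete_choice]

# Flattened form: the continuous action concatenated with the one-hot encoding.
flattened = np.concatenate([continuous, one_hot])
# -> [0.5, -0.25, 0.0, 0.0, 1.0]; flattened size = continuous_size + sum(branches)
```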

ml-agents/mlagents/trainers/torch/action_log_probs.py (44)


import numpy as np
from mlagents.trainers.torch.utils import ModelUtils
from mlagents_envs.base_env import _ActionTupleBase
class LogProbsTuple(_ActionTupleBase):
"""
An object whose fields correspond to the log probs of actions of different types.
Continuous and discrete are numpy arrays
Dimensions are of (n_agents, continuous_size) and (n_agents, discrete_size),
respectively. Note, this also holds when continuous or discrete size is
zero.
"""
def get_discrete_dtype(self) -> np.dtype:
"""
The dtype of a discrete log probability.
"""
return np.float32
class ActionLogProbs(NamedTuple):

"""
return torch.cat(self.all_discrete_list, dim=1)
def to_numpy_dict(self) -> Dict[str, np.ndarray]:
def to_log_probs_tuple(self) -> LogProbsTuple:
Returns a Dict of np arrays with an entry corresponding to the continuous log probs
and an entry corresponding to the discrete log probs. "continuous_log_probs" and
"discrete_log_probs" are added to the agents buffer individually to maintain a flat buffer.
Returns a LogProbsTuple. Only adds if tensor is not None. Otherwise,
LogProbsTuple uses a default.
array_dict: Dict[str, np.ndarray] = {}
log_probs_tuple = LogProbsTuple()
array_dict["continuous_log_probs"] = ModelUtils.to_numpy(
self.continuous_tensor
)
continuous = ModelUtils.to_numpy(self.continuous_tensor)
log_probs_tuple.add_continuous(continuous)
array_dict["discrete_log_probs"] = ModelUtils.to_numpy(self.discrete_tensor)
return array_dict
discrete = ModelUtils.to_numpy(self.discrete_tensor)
log_probs_tuple.add_discrete(discrete)
return log_probs_tuple
def _to_tensor_list(self) -> List[torch.Tensor]:
"""

continuous = ModelUtils.list_to_tensor(buff["continuous_log_probs"])
if "discrete_log_probs" in buff:
discrete_tensor = ModelUtils.list_to_tensor(buff["discrete_log_probs"])
discrete = [
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
# This will keep discrete_list = None which enables flatten()
if discrete_tensor.shape[1] > 0:
discrete = [
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return ActionLogProbs(continuous, discrete, None)
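Since `LogProbsTuple` derives from the same `_ActionTupleBase`, it behaves like `ActionTuple` except that its discrete field holds `float32` log probabilities; a short sketch under that assumption:

```python
import numpy as np
from mlagents.trainers.torch.action_log_probs import LogProbsTuple

# Log probs for 3 agents: 2 continuous dims and 1 discrete branch.
lp = LogProbsTuple(
    continuous=np.zeros((3, 2), dtype=np.float32),
    discrete=np.zeros((3, 1), dtype=np.float32),  # float32 here, int32 in ActionTuple
)
assert lp.discrete.dtype == np.float32
```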

ml-agents/mlagents/trainers/torch/action_model.py (84)


class DistInstances(NamedTuple):
continuous: DistInstance
discrete: List[DiscreteDistInstance]
"""
A NamedTuple with fields corresponding to the DistInstance objects
output by continuous and discrete distributions, respectively. Discrete distributions
output a list of DistInstance objects whereas continuous distributions output a single
DistInstance object.
"""
continuous: Optional[DistInstance]
discrete: Optional[List[DiscreteDistInstance]]
class ActionModel(nn.Module):

conditional_sigma: bool = False,
tanh_squash: bool = False,
):
"""
A torch module that represents the action space of a policy. The ActionModel may contain
a continuous distribution, a discrete distribution or both where construction depends on
the action_spec. The ActionModel uses the encoded input of the network body to parameterize
these distributions. The forward method of this module outputs the action, log probs,
and entropies given the encoding from the network body.
:params hidden_size: Size of the input to the ActionModel.
:params action_spec: The ActionSpec defining the action space dimensions and distributions.
:params conditional_sigma: Whether or not the std of a Gaussian is conditioned on state.
:params tanh_squash: Whether to squash the output of a Gaussian with the tanh function.
"""
super().__init__()
self.encoding_size = hidden_size
self.action_spec = action_spec

def _sample_action(self, dists: DistInstances) -> AgentAction:
"""
Samples actions from a DistInstances tuple
:params dists: The DistInstances tuple
:return: An AgentAction corresponding to the actions sampled from the DistInstances
if self.action_spec.continuous_size > 0:
# This checks None because mypy complains otherwise
if dists.continuous is not None:
if self.action_spec.discrete_size > 0:
if dists.discrete is not None:
discrete_action = []
for discrete_dist in dists.discrete:
discrete_action.append(discrete_dist.sample())

"""
Creates a DistInstances tuple using the continuous and discrete distributions
:params inputs: The encoding from the network body
:params masks: Action masks for discrete actions
:return: A DistInstances tuple
"""
if self.action_spec.continuous_size > 0:
continuous_dist = self._continuous_distribution(inputs, masks)
if self.action_spec.discrete_size > 0:
# This checks None because mypy complains otherwise
if self._continuous_distribution is not None:
continuous_dist = self._continuous_distribution(inputs)
if self._discrete_distribution is not None:
discrete_dist = self._discrete_distribution(inputs, masks)
return DistInstances(continuous_dist, discrete_dist)

"""
Computes the log probabilities of the actions given distributions and entropies of
the given distributions.
:params actions: The AgentAction
:params dists: The DistInstances tuple
:return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
"""
if self.action_spec.continuous_size > 0:
# This checks None because mypy complains otherwise
if dists.continuous is not None:
if self.action_spec.discrete_size > 0:
if dists.discrete is not None:
actions.discrete_list, dists.discrete
actions.discrete_list, dists.discrete # type: ignore
):
discrete_log_prob = discrete_dist.log_prob(discrete_action)
entropies_list.append(discrete_dist.entropy())

def evaluate(
self, inputs: torch.Tensor, masks: torch.Tensor, actions: AgentAction
) -> Tuple[ActionLogProbs, torch.Tensor]:
"""
Given actions and encoding from the network body, gets the distributions and
computes the log probabilities and entropies.
:params inputs: The encoding from the network body
:params masks: Action masks for discrete actions
:params actions: The AgentAction
:return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
"""
dists = self._get_dists(inputs, masks)
log_probs, entropies = self._get_probs_and_entropy(actions, dists)
# Use the sum of entropy across actions, not the mean

def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
"""
Gets the tensors corresponding to the output of the policy network to be used for
inference. Called by the Actor's forward call.
:params inputs: The encoding from the network body
:params masks: Action masks for discrete actions
:return: A tuple of torch tensors corresponding to the inference output
"""
if self.action_spec.continuous_size > 0:
if self.action_spec.continuous_size > 0 and dists.continuous is not None:
if self.action_spec.discrete_size > 0:
if self.action_spec.discrete_size > 0 and dists.discrete is not None:
discrete_out = [
discrete_dist.exported_model_output()
for discrete_dist in dists.discrete

def forward(
self, inputs: torch.Tensor, masks: torch.Tensor
) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor]:
"""
The forward method of this module. Outputs the action, log probs,
and entropies given the encoding from the network body.
:params inputs: The encoding from the network body
:params masks: Action masks for discrete actions
:return: Given the input, an AgentAction of the actions generated by the policy and the corresponding
ActionLogProbs and entropies.
"""
dists = self._get_dists(inputs, masks)
actions = self._sample_action(dists)
log_probs, entropies = self._get_probs_and_entropy(actions, dists)
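Putting the pieces together, a hypothetical usage sketch based only on the signatures visible in this diff (the spec sizes, batch shapes, and tensors are made up for illustration):

```python
import torch
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch.action_model import ActionModel

# Hybrid spec: 2 continuous dimensions plus one 3-way discrete branch.
spec = ActionSpec(continuous_size=2, discrete_branches=(3,))
model = ActionModel(hidden_size=64, action_spec=spec)

encoding = torch.randn(5, 64)  # stand-in for the network body's output (batch of 5)
masks = torch.ones(5, 3)       # all discrete actions allowed
actions, log_probs, entropies = model.forward(encoding, masks)
# actions: AgentAction; log_probs: ActionLogProbs; entropies: torch.Tensor
```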

ml-agents/mlagents/trainers/torch/agent_action.py (21)


import numpy as np
from mlagents.trainers.torch.utils import ModelUtils
from mlagents_envs.base_env import ActionTuple
class AgentAction(NamedTuple):

"""
return torch.stack(self.discrete_list, dim=-1)
def to_numpy_dict(self) -> Dict[str, np.ndarray]:
def to_action_tuple(self) -> ActionTuple:
Returns a Dict of np arrays with an entry corresponding to the continuous action
and an entry corresponding to the discrete action. "continuous_action" and
"discrete_action" are added to the agents buffer individually to maintain a flat buffer.
Returns an ActionTuple
array_dict: Dict[str, np.ndarray] = {}
action_tuple = ActionTuple()
array_dict["continuous_action"] = ModelUtils.to_numpy(
self.continuous_tensor
)
continuous = ModelUtils.to_numpy(self.continuous_tensor)
action_tuple.add_continuous(continuous)
array_dict["discrete_action"] = ModelUtils.to_numpy(
self.discrete_tensor[:, 0, :]
)
return array_dict
discrete = ModelUtils.to_numpy(self.discrete_tensor[:, 0, :])
action_tuple.add_discrete(discrete)
return action_tuple
@staticmethod
def from_dict(buff: Dict[str, np.ndarray]) -> "AgentAction":

ml-agents/mlagents/trainers/torch/components/bc/module.py (12)


expert_actions: torch.Tensor,
) -> torch.Tensor:
bc_loss = 0
if self.policy.action_spec.continuous_size > 0:
if self.policy.behavior_spec.action_spec.continuous_size > 0:
if self.policy.action_spec.discrete_size > 0:
if self.policy.behavior_spec.action_spec.discrete_size > 0:
self.policy.action_spec.discrete_branches,
self.policy.behavior_spec.action_spec.discrete_branches,
log_prob_branches = ModelUtils.break_into_branches(
log_probs.all_discrete_tensor,
self.policy.behavior_spec.action_spec.discrete_branches,

vec_obs = [ModelUtils.list_to_tensor(mini_batch_demo["vector_obs"])]
act_masks = None
expert_actions = AgentAction.from_dict(mini_batch_demo)
if self.policy.action_spec.discrete_size > 0:
if self.policy.behavior_spec.action_spec.discrete_size > 0:
act_masks = ModelUtils.list_to_tensor(
np.ones(
(

else:
vis_obs = []
selected_actions, log_probs, _, _, _ = self.policy.sample_actions(
selected_actions, log_probs, _, _ = self.policy.sample_actions(
vec_obs,
vis_obs,
masks=act_masks,

ml-agents/mlagents/trainers/torch/distributions.py (2)


             torch.zeros(1, num_outputs, requires_grad=True)
         )

-    def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
+    def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
         mu = self.mu(inputs)
         if self.conditional_sigma:
             log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)

25
ml-agents/mlagents/trainers/trajectory.py


-from typing import List, NamedTuple, Dict
+from typing import List, NamedTuple
+from mlagents_envs.base_env import ActionTuple
+from mlagents.trainers.torch.action_log_probs import LogProbsTuple

 class AgentExperience(NamedTuple):
-    action: Dict[str, np.ndarray]
-    action_probs: Dict[str, np.ndarray]
+    action: ActionTuple
+    action_probs: LogProbsTuple
     action_pre: np.ndarray  # TODO: Remove this
     action_mask: np.ndarray
     prev_action: np.ndarray

agent_buffer_trajectory["actions_pre"].append(exp.action_pre)
# Adds the log prob and action of continuous/discrete separately
for act_type, act_array in exp.action.items():
agent_buffer_trajectory[act_type].append(act_array)
for log_type, log_array in exp.action_probs.items():
agent_buffer_trajectory[log_type].append(log_array)
agent_buffer_trajectory["continuous_action"].append(exp.action.continuous)
agent_buffer_trajectory["discrete_action"].append(exp.action.discrete)
agent_buffer_trajectory["continuous_log_probs"].append(
exp.action_probs.continuous
)
agent_buffer_trajectory["discrete_log_probs"].append(
exp.action_probs.discrete
)
# Store action masks if necessary. Note that 1 means active, while
# in AgentExperience False means active.

             # This should never be needed unless the environment somehow doesn't supply the
             # action mask in a discrete space.
-            if "discrete_action" in exp.action:
-                action_shape = exp.action["discrete_action"].shape
-            else:
-                action_shape = exp.action["continuous_action"].shape
+            action_shape = exp.action.discrete.shape
             agent_buffer_trajectory["action_mask"].append(
                 np.ones(action_shape, dtype=np.float32), padding_value=1
             )

1
utils/make_readme_table.py


ReleaseInfo("release_6", "1.3.0", "0.19.0", "August 12, 2020"),
ReleaseInfo("release_7", "1.4.0", "0.20.0", "September 16, 2020"),
ReleaseInfo("release_8", "1.5.0", "0.21.0", "October 14, 2020"),
ReleaseInfo("release_9", "1.5.0", "0.21.1", "November 4, 2020"),
]
MAX_DAYS = 150 # do not print releases older than this many days

8
Project/Assets/ML-Agents/Examples/Match3.meta


fileFormatVersion: 2
guid: 85094c6352d9e43c497a54fef35e4d76
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

67
com.unity.ml-agents.extensions/Documentation~/Match3.md


# Match-3 Game Support
We provide some utilities to integrate ML-Agents with Match-3 games.
## AbstractBoard class
The `AbstractBoard` is the bridge between ML-Agents and your game. It allows ML-Agents to
* ask your game what the "color" of a cell is
* ask whether the cell is a "special" piece type or not
* ask your game whether a move is allowed
* request that your game make a move
These are handled by implementing the `GetCellType()`, `GetSpecialType()`, `IsMoveValid()`, and `MakeMove()` abstract methods.
The AbstractBoard also tracks the number of rows, columns, and potential piece types that the board can have.
#### `public abstract int GetCellType(int row, int col)`
Returns the "color" of piece at the given row and column.
This should be between 0 and NumCellTypes-1 (inclusive).
The actual order of the values doesn't matter.
#### `public abstract int GetSpecialType(int row, int col)`
Returns the special type of the piece at the given row and column.
This should be between 0 and NumSpecialTypes (inclusive).
The actual order of the values doesn't matter.
#### `public abstract bool IsMoveValid(Move m)`
Check whether the particular `Move` is valid for the game.
The actual results will depend on the rules of the game, but we provide the `SimpleIsMoveValid()` method
that handles basic match3 rules with no special or immovable pieces.
#### `public abstract bool MakeMove(Move m)`
Instruct the game to make the given move. Returns true if the move was made.
Note that during training, a move that was marked as invalid may occasionally still be
requested. If this happens, it is safe to do nothing and request another move.
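For illustration, a minimal sketch of an `AbstractBoard` implementation might look like the following. The class name, the board dimensions, and the `m_Cells` array are hypothetical; your game supplies the real state.
```csharp
using Unity.MLAgents.Extensions.Match3;

public class MyBoard : AbstractBoard
{
    int[,] m_Cells; // [column, row] cell types, owned by your game logic

    void Awake()
    {
        Rows = 9;
        Columns = 8;
        NumCellTypes = 6;
        NumSpecialTypes = 0;
        m_Cells = new int[Columns, Rows];
    }

    public override int GetCellType(int row, int col)
    {
        return m_Cells[col, row];
    }

    // This sketch has no special pieces, so every cell is the basic type.
    public override int GetSpecialType(int row, int col)
    {
        return 0;
    }

    // Defer to the provided basic match-3 rules.
    public override bool IsMoveValid(Move m)
    {
        return SimpleIsMoveValid(m);
    }

    public override bool MakeMove(Move m)
    {
        if (!IsMoveValid(m))
        {
            return false;
        }
        // Swap the two cells described by the move.
        var (otherRow, otherCol) = m.OtherCell();
        var temp = m_Cells[m.Column, m.Row];
        m_Cells[m.Column, m.Row] = m_Cells[otherCol, otherRow];
        m_Cells[otherCol, otherRow] = temp;
        return true;
    }
}
```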
## Move struct
The Move struct encapsulates a swap of two adjacent cells. You can get the number of potential moves
for a board of a given size with. `Move.NumPotentialMoves(NumRows, NumColumns)`. There are two helper
functions to create a new `Move`:
* `public static Move FromMoveIndex(int moveIndex, int maxRows, int maxCols)` can be used to
iterate over all potential moves for the board by looping from 0 to `Move.NumPotentialMoves()`
* `public static Move FromPositionAndDirection(int row, int col, Direction dir, int maxRows, int maxCols)` creates
a `Move` from a row, column, and direction (and board size).
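As a small usage sketch (assuming `board` is your `AbstractBoard` implementation), counting the valid moves on a board can loop over every move index:
```csharp
var numValid = 0;
var numMoves = Move.NumPotentialMoves(board.Rows, board.Columns);
for (var i = 0; i < numMoves; i++)
{
    var move = Move.FromMoveIndex(i, board.Rows, board.Columns);
    if (board.IsMoveValid(move))
    {
        numValid++;
    }
}
```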
## `Match3Sensor` and `Match3SensorComponent` classes
The `Match3Sensor` generates observations about the state using the `AbstractBoard` interface. You can
choose whether to use vector or "visual" observations; in theory, visual observations should perform
better because they are 2-dimensional like the board, but we need to experiment more on this.
A `Match3SensorComponent` generates a `Match3Sensor` at runtime, and should be added to the same GameObject
as your `Agent` implementation. You do not need to write any additional code to use them.
## `Match3Actuator` and `Match3ActuatorComponent` classes
The `Match3Actuator` converts actions from training or inference into a `Move` that is sent to `AbstractBoard.MakeMove()`.
It also checks `AbstractBoard.IsMoveValid()` for each potential move and uses this to set the action mask for the Agent.
A `Match3ActuatorComponent` generates a `Match3Actuator` at runtime, and should be added to the same GameObject
as your `Agent` implementation. You do not need to write any additional code to use them.
# Setting up match-3 simulation
* Implement the `AbstractBoard` methods to integrate with your game.
* Give the `Agent` rewards when it does what you want it to (matches multiple pieces in a row, clears pieces of a certain
type, etc.).
* Add the `Agent`, `AbstractBoard` implementation, `Match3SensorComponent`, and `Match3ActuatorComponent` to the same
`GameObject`.
* Call `Agent.RequestDecision()` when you're ready for the `Agent` to make a move on the next `Academy` step. During
the next `Academy` step, the `MakeMove()` method on the board will be called; the sketch below shows how these pieces fit together.
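As a rough sketch of the wiring (the `MyBoard` class, the `OnMoveResolved` hook, and the 0.01 reward scale are all hypothetical):
```csharp
using Unity.MLAgents;

// Attach this to the same GameObject as your AbstractBoard implementation,
// a Match3SensorComponent, and a Match3ActuatorComponent.
public class MyMatch3Agent : Agent
{
    MyBoard m_Board;

    void Awake()
    {
        m_Board = GetComponent<MyBoard>();
    }

    // Call this from your game loop once a move has finished resolving.
    public void OnMoveResolved(int pointsEarned)
    {
        // Reward the Agent for the points the move earned, then request the
        // next decision; the Match3Actuator will call MakeMove() on the next
        // Academy step.
        AddReward(0.01f * pointsEarned);
        RequestDecision();
    }
}
```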

3
com.unity.ml-agents.extensions/Runtime/Match3.meta


fileFormatVersion: 2
guid: 569f8fa2b7dd477c9b71f09e9d633832
timeCreated: 1600465975

3
com.unity.ml-agents.extensions/Tests/Editor/Match3.meta


fileFormatVersion: 2
guid: 77b0212dde404f7c8ce9aac13bd550b8
timeCreated: 1601332716

75
config/ppo/Match3.yaml


behaviors:
Match3VectorObs:
trainer_type: ppo
hyperparameters:
batch_size: 64
buffer_size: 12000
learning_rate: 0.0003
beta: 0.001
epsilon: 0.2
lambd: 0.99
num_epoch: 3
learning_rate_schedule: constant
network_settings:
normalize: true
hidden_units: 128
num_layers: 2
vis_encode_type: match3
reward_signals:
extrinsic:
gamma: 0.99
strength: 1.0
keep_checkpoints: 5
max_steps: 5000000
time_horizon: 1000
summary_freq: 10000
threaded: true
Match3VisualObs:
trainer_type: ppo
hyperparameters:
batch_size: 64
buffer_size: 12000
learning_rate: 0.0003
beta: 0.001
epsilon: 0.2
lambd: 0.99
num_epoch: 3
learning_rate_schedule: constant
network_settings:
normalize: true
hidden_units: 128
num_layers: 2
vis_encode_type: match3
reward_signals:
extrinsic:
gamma: 0.99
strength: 1.0
keep_checkpoints: 5
max_steps: 5000000
time_horizon: 1000
summary_freq: 10000
threaded: true
Match3SimpleHeuristic:
# Settings can be very simple since we don't care about actually training the model
trainer_type: ppo
hyperparameters:
batch_size: 64
buffer_size: 128
network_settings:
hidden_units: 4
num_layers: 1
max_steps: 5000000
summary_freq: 10000
threaded: true
Match3GreedyHeuristic:
# Settings can be very simple since we don't care about actually training the model
trainer_type: ppo
hyperparameters:
batch_size: 64
buffer_size: 128
network_settings:
hidden_units: 4
num_layers: 1
max_steps: 5000000
summary_freq: 10000
threaded: true

77
docs/images/match3.png

Before | After
Width: 297 | Height: 320 | Size: 22 KiB

8
Project/Assets/ML-Agents/Examples/Match3/Prefabs.meta


fileFormatVersion: 2
guid: 8519802844d8d4233b4c6f6758ab8322
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

174
Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!1 &3508723250470608007
GameObject:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
serializedVersion: 6
m_Component:
- component: {fileID: 3508723250470608008}
- component: {fileID: 3508723250470608010}
- component: {fileID: 3508723250470608012}
- component: {fileID: 3508723250470608011}
- component: {fileID: 3508723250470608009}
- component: {fileID: 3508723250470608013}
- component: {fileID: 3508723250470608014}
m_Layer: 0
m_Name: Match3 Agent
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!4 &3508723250470608008
Transform:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3508723250470608007}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 0, z: 0}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children: []
m_Father: {fileID: 3508723250774301920}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!114 &3508723250470608010
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3508723250470608007}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 5d1c4e0b1822b495aa52bc52839ecb30, type: 3}
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
VectorObservationSize: 0
NumStackedVectorObservations: 1
VectorActionSize:
VectorActionDescriptions: []
VectorActionSpaceType: 0
m_Model: {fileID: 11400000, guid: c34da50737a3c4a50918002b20b2b927, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0
m_BehaviorName: Match3SmartHeuristic
TeamId: 0
m_UseChildSensors: 1
m_UseChildActuators: 1
m_ObservableAttributeHandling: 0
--- !u!114 &3508723250470608012
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3508723250470608007}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: d982f0cd92214bd2b689be838fa40c44, type: 3}
m_Name:
m_EditorClassIdentifier:
agentParameters:
maxStep: 0
hasUpgradedFromAgentParameters: 1
MaxStep: 0
Board: {fileID: 0}
MoveTime: 0.25
MaxMoves: 500
UseSmartHeuristic: 1
--- !u!114 &3508723250470608011
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3508723250470608007}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: abebb7ad4a5547d7a3b04373784ff195, type: 3}
m_Name:
m_EditorClassIdentifier:
DebugEdgeIndex: -1
--- !u!114 &3508723250470608009
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3508723250470608007}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 6d852a063770348b68caa91b8e7642a5, type: 3}
m_Name:
m_EditorClassIdentifier:
Rows: 9
Columns: 8
NumCellTypes: 6
NumSpecialTypes: 2
RandomSeed: -1
BasicCellPoints: 1
SpecialCell1Points: 2
SpecialCell2Points: 3
--- !u!114 &3508723250470608013
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3508723250470608007}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3}
m_Name:
m_EditorClassIdentifier:
ForceHeuristic: 1
--- !u!114 &3508723250470608014
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3508723250470608007}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
m_Name:
m_EditorClassIdentifier:
ObservationType: 0
--- !u!1 &3508723250774301855
GameObject:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
serializedVersion: 6
m_Component:
- component: {fileID: 3508723250774301920}
m_Layer: 0
m_Name: Match3Heuristic
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!4 &3508723250774301920
Transform:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3508723250774301855}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 0, z: 0}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children:
- {fileID: 3508723250470608008}
m_Father: {fileID: 0}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}

7
Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab.meta


fileFormatVersion: 2
guid: 2fafdcd0587684641b03b11f04454f1b
PrefabImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

170
Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!1 &2118285883905619929
GameObject:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
serializedVersion: 6
m_Component:
- component: {fileID: 2118285883905619878}
m_Layer: 0
m_Name: Match3VectorObs
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!4 &2118285883905619878
Transform:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 2118285883905619929}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 0, z: 0}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children:
- {fileID: 2118285884327540686}
m_Father: {fileID: 0}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!1 &2118285884327540673
GameObject:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
serializedVersion: 6
m_Component:
- component: {fileID: 2118285884327540686}
- component: {fileID: 2118285884327540684}
- component: {fileID: 2118285884327540682}
- component: {fileID: 2118285884327540685}
- component: {fileID: 2118285884327540687}
- component: {fileID: 2118285884327540683}
- component: {fileID: 2118285884327540680}
m_Layer: 0
m_Name: Match3 Agent
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!4 &2118285884327540686
Transform:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 2118285884327540673}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 0, z: 0}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children: []
m_Father: {fileID: 2118285883905619878}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!114 &2118285884327540684
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 2118285884327540673}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 5d1c4e0b1822b495aa52bc52839ecb30, type: 3}
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
VectorObservationSize: 0
NumStackedVectorObservations: 1
VectorActionSize:
VectorActionDescriptions: []
VectorActionSpaceType: 0
m_Model: {fileID: 11400000, guid: 9e89b8e81974148d3b7213530d00589d, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0
m_BehaviorName: Match3VectorObs
TeamId: 0
m_UseChildSensors: 1
m_UseChildActuators: 1
m_ObservableAttributeHandling: 0
--- !u!114 &2118285884327540682
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 2118285884327540673}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: d982f0cd92214bd2b689be838fa40c44, type: 3}
m_Name:
m_EditorClassIdentifier:
agentParameters:
maxStep: 0
hasUpgradedFromAgentParameters: 1
MaxStep: 0
Board: {fileID: 0}
MoveTime: 0.25
MaxMoves: 500
--- !u!114 &2118285884327540685
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 2118285884327540673}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: abebb7ad4a5547d7a3b04373784ff195, type: 3}
m_Name:
m_EditorClassIdentifier:
DebugEdgeIndex: -1
--- !u!114 &2118285884327540687
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 2118285884327540673}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 6d852a063770348b68caa91b8e7642a5, type: 3}
m_Name:
m_EditorClassIdentifier:
Rows: 9
Columns: 8
NumCellTypes: 6
NumSpecialTypes: 2
RandomSeed: -1
--- !u!114 &2118285884327540683
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 2118285884327540673}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3}
m_Name:
m_EditorClassIdentifier:
ForceRandom: 0
--- !u!114 &2118285884327540680
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 2118285884327540673}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
m_Name:
m_EditorClassIdentifier:
ObservationType: 0

7
Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab.meta


fileFormatVersion: 2
guid: 6944ca02359f5427aa13c8551236a824
PrefabImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

170
Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!1 &3019509691567202678
GameObject:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
serializedVersion: 6
m_Component:
- component: {fileID: 3019509691567202569}
m_Layer: 0
m_Name: Match3VisualObs
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!4 &3019509691567202569
Transform:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3019509691567202678}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 0, z: 0}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children:
- {fileID: 3019509692332007777}
m_Father: {fileID: 0}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!1 &3019509692332007790
GameObject:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
serializedVersion: 6
m_Component:
- component: {fileID: 3019509692332007777}
- component: {fileID: 3019509692332007779}
- component: {fileID: 3019509692332007781}
- component: {fileID: 3019509692332007778}
- component: {fileID: 3019509692332007776}
- component: {fileID: 3019509692332007780}
- component: {fileID: 3019509692332007783}
m_Layer: 0
m_Name: Match3 Agent
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!4 &3019509692332007777
Transform:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3019509692332007790}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 0, z: 0}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children: []
m_Father: {fileID: 3019509691567202569}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!114 &3019509692332007779
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3019509692332007790}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 5d1c4e0b1822b495aa52bc52839ecb30, type: 3}
m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
VectorObservationSize: 0
NumStackedVectorObservations: 1
VectorActionSize:
VectorActionDescriptions: []
VectorActionSpaceType: 0
m_Model: {fileID: 11400000, guid: 48d14da88fea74d0693c691c6e3f2e34, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0
m_BehaviorName: Match3VisualObs
TeamId: 0
m_UseChildSensors: 1
m_UseChildActuators: 1
m_ObservableAttributeHandling: 0
--- !u!114 &3019509692332007781
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3019509692332007790}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: d982f0cd92214bd2b689be838fa40c44, type: 3}
m_Name:
m_EditorClassIdentifier:
agentParameters:
maxStep: 0
hasUpgradedFromAgentParameters: 1
MaxStep: 0
Board: {fileID: 0}
MoveTime: 0.25
MaxMoves: 500
--- !u!114 &3019509692332007778
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3019509692332007790}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: abebb7ad4a5547d7a3b04373784ff195, type: 3}
m_Name:
m_EditorClassIdentifier:
DebugEdgeIndex: -1
--- !u!114 &3019509692332007776
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3019509692332007790}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 6d852a063770348b68caa91b8e7642a5, type: 3}
m_Name:
m_EditorClassIdentifier:
Rows: 9
Columns: 8
NumCellTypes: 6
NumSpecialTypes: 2
RandomSeed: -1
--- !u!114 &3019509692332007780
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3019509692332007790}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3}
m_Name:
m_EditorClassIdentifier:
ForceRandom: 0
--- !u!114 &3019509692332007783
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 3019509692332007790}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
m_Name:
m_EditorClassIdentifier:
ObservationType: 2

7
Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab.meta


fileFormatVersion: 2
guid: aaa471bd5e2014848a66917476671aed
PrefabImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

8
Project/Assets/ML-Agents/Examples/Match3/Scenes.meta


fileFormatVersion: 2
guid: e033fb0df67684ebf961ed115870ff10
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

1001
Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity
File diff suppressed because it is too large

7
Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity.meta


fileFormatVersion: 2
guid: 2e09c5458f1494f9dad9cd6d09dff964
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

8
Project/Assets/ML-Agents/Examples/Match3/Scripts.meta


fileFormatVersion: 2
guid: be7a27f4291944d3dba4f696e1af4209
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

373
Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs


using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Extensions.Match3;
namespace Unity.MLAgentsExamples
{
/// <summary>
/// State of the "game" when showing all steps of the simulation. This is only used outside of training.
/// The state diagram is
///
///      | <--------------------------------------------- ^
///      |                                                 |
///      v                                                 |
/// +--------+      +-------+      +-----+      +------+
/// |Find    | ---> |Clear  | ---> |Drop | ---> |Fill  |
/// |Matches |      |Matched|      |     |      |Empty |
/// +--------+      +-------+      +-----+      +------+
///      |    ^
///      |    |
///      v    |
/// +--------+
/// |Wait for|
/// |Move    |
/// +--------+
///
/// The state advances every "MoveTime" seconds.
/// </summary>
enum State
{
/// <summary>
/// Guard value, should never happen.
/// </summary>
Invalid = -1,
/// <summary>
/// Look for matches. If there are matches, the next state is ClearMatched, otherwise WaitForMove.
/// </summary>
FindMatches = 0,
/// <summary>
/// Remove matched cells and replace them with a placeholder value.
/// </summary>
ClearMatched = 1,
/// <summary>
/// Move cells "down" to fill empty space.
/// </summary>
Drop = 2,
/// <summary>
/// Replace empty cells with new random values.
/// </summary>
FillEmpty = 3,
/// <summary>
/// Request a move from the Agent.
/// </summary>
WaitForMove = 4,
}
public enum HeuristicQuality
{
/// <summary>
/// The heuristic will pick any valid move at random.
/// </summary>
RandomValidMove,
/// <summary>
/// The heuristic will pick the move that scores the most points.
/// This only looks at the immediate move, and doesn't consider where cells will fall.
/// </summary>
Greedy
}
public class Match3Agent : Agent
{
[HideInInspector]
public Match3Board Board;
public float MoveTime = 1.0f;
public int MaxMoves = 500;
public HeuristicQuality HeuristicQuality = HeuristicQuality.RandomValidMove;
State m_CurrentState = State.WaitForMove;
float m_TimeUntilMove;
private int m_MovesMade;
private System.Random m_Random;
private const float k_RewardMultiplier = 0.01f;
void Awake()
{
Board = GetComponent<Match3Board>();
var seed = Board.RandomSeed == -1 ? gameObject.GetInstanceID() : Board.RandomSeed + 1;
m_Random = new System.Random(seed);
}
public override void OnEpisodeBegin()
{
base.OnEpisodeBegin();
Board.InitSettled();
m_CurrentState = State.FindMatches;
m_TimeUntilMove = MoveTime;
m_MovesMade = 0;
}
private void FixedUpdate()
{
if (Academy.Instance.IsCommunicatorOn)
{
FastUpdate();
}
else
{
AnimatedUpdate();
}
// We can't use the normal MaxSteps system to decide when to end an episode,
// since different agents will make moves at different frequencies (depending on the number of
// chained moves). So track a number of moves per Agent and manually interrupt the episode.
if (m_MovesMade >= MaxMoves)
{
EpisodeInterrupted();
}
}
void FastUpdate()
{
while (true)
{
var hasMatched = Board.MarkMatchedCells();
if (!hasMatched)
{
break;
}
var pointsEarned = Board.ClearMatchedCells();
AddReward(k_RewardMultiplier * pointsEarned);
Board.DropCells();
Board.FillFromAbove();
}
while (!HasValidMoves())
{
// Shuffle the board until we have a valid move.
Board.InitSettled();
}
RequestDecision();
m_MovesMade++;
}
void AnimatedUpdate()
{
m_TimeUntilMove -= Time.deltaTime;
if (m_TimeUntilMove > 0.0f)
{
return;
}
m_TimeUntilMove = MoveTime;
var nextState = State.Invalid;
switch (m_CurrentState)
{
case State.FindMatches:
var hasMatched = Board.MarkMatchedCells();
nextState = hasMatched ? State.ClearMatched : State.WaitForMove;
if (nextState == State.WaitForMove)
{
m_MovesMade++;
}
break;
case State.ClearMatched:
var pointsEarned = Board.ClearMatchedCells();
AddReward(k_RewardMultiplier * pointsEarned);
nextState = State.Drop;
break;
case State.Drop:
Board.DropCells();
nextState = State.FillEmpty;
break;
case State.FillEmpty:
Board.FillFromAbove();
nextState = State.FindMatches;
break;
case State.WaitForMove:
while (true)
{
// Shuffle the board until we have a valid move.
bool hasMoves = HasValidMoves();
if (hasMoves)
{
break;
}
Board.InitSettled();
}
RequestDecision();
nextState = State.FindMatches;
break;
default:
throw new ArgumentOutOfRangeException();
}
m_CurrentState = nextState;
}
bool HasValidMoves()
{
foreach (var move in Board.ValidMoves())
{
return true;
}
return false;
}
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActions = actionsOut.DiscreteActions;
discreteActions[0] = GreedyMove();
}
int GreedyMove()
{
var pointsByType = new[] { Board.BasicCellPoints, Board.SpecialCell1Points, Board.SpecialCell2Points };
var bestMoveIndex = 0;
var bestMovePoints = -1;
var numMovesAtCurrentScore = 0;
foreach (var move in Board.ValidMoves())
{
var movePoints = HeuristicQuality == HeuristicQuality.Greedy ? EvalMovePoints(move, pointsByType) : 1;
if (movePoints < bestMovePoints)
{
// Worse, skip
continue;
}
if (movePoints > bestMovePoints)
{
// Better, keep
bestMovePoints = movePoints;
bestMoveIndex = move.MoveIndex;
numMovesAtCurrentScore = 1;
}
else
{
// Tied for best - use reservoir sampling to make sure we select from equal moves uniformly.
// See https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm
numMovesAtCurrentScore++;
var randVal = m_Random.Next(0, numMovesAtCurrentScore);
if (randVal == 0)
{
// Keep the new one
bestMoveIndex = move.MoveIndex;
}
}
}
return bestMoveIndex;
}
int EvalMovePoints(Move move, int[] pointsByType)
{
// Counts the expected points for making the move.
var moveVal = Board.GetCellType(move.Row, move.Column);
var moveSpecial = Board.GetSpecialType(move.Row, move.Column);
var (otherRow, otherCol) = move.OtherCell();
var oppositeVal = Board.GetCellType(otherRow, otherCol);
var oppositeSpecial = Board.GetSpecialType(otherRow, otherCol);
int movePoints = EvalHalfMove(
otherRow, otherCol, moveVal, moveSpecial, move.Direction, pointsByType
);
int otherPoints = EvalHalfMove(
move.Row, move.Column, oppositeVal, oppositeSpecial, move.OtherDirection(), pointsByType
);
return movePoints + otherPoints;
}
int EvalHalfMove(int newRow, int newCol, int newValue, int newSpecial, Direction incomingDirection, int[] pointsByType)
{
// This is essentially a duplicate of AbstractBoard.CheckHalfMove, but it also counts the points for the move.
int matchedLeft = 0, matchedRight = 0, matchedUp = 0, matchedDown = 0;
int scoreLeft = 0, scoreRight = 0, scoreUp = 0, scoreDown = 0;
if (incomingDirection != Direction.Right)
{
for (var c = newCol - 1; c >= 0; c--)
{
if (Board.GetCellType(newRow, c) == newValue)
{
matchedLeft++;
scoreLeft += pointsByType[Board.GetSpecialType(newRow, c)];
}
else
break;
}
}
if (incomingDirection != Direction.Left)
{
for (var c = newCol + 1; c < Board.Columns; c++)
{
if (Board.GetCellType(newRow, c) == newValue)
{
matchedRight++;
scoreRight += pointsByType[Board.GetSpecialType(newRow, c)];
}
else
break;
}
}
if (incomingDirection != Direction.Down)
{
for (var r = newRow + 1; r < Board.Rows; r++)
{
if (Board.GetCellType(r, newCol) == newValue)
{
matchedUp++;
scoreUp += pointsByType[Board.GetSpecialType(r, newCol)];
}
else
break;
}
}
if (incomingDirection != Direction.Up)
{
for (var r = newRow - 1; r >= 0; r--)
{
if (Board.GetCellType(r, newCol) == newValue)
{
matchedDown++;
scoreDown += pointsByType[Board.GetSpecialType(r, newCol)];
}
else
break;
}
}
if ((matchedUp + matchedDown >= 2) || (matchedLeft + matchedRight >= 2))
{
// It's a match. Start from counting the piece being moved
var totalScore = pointsByType[newSpecial];
if (matchedUp + matchedDown >= 2)
{
totalScore += scoreUp + scoreDown;
}
if (matchedLeft + matchedRight >= 2)
{
totalScore += scoreLeft + scoreRight;
}
return totalScore;
}
return 0;
}
}
}

3
Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs.meta


fileFormatVersion: 2
guid: d982f0cd92214bd2b689be838fa40c44
timeCreated: 1598221207

272
Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs


using Unity.MLAgents.Extensions.Match3;
using UnityEngine;
namespace Unity.MLAgentsExamples
{
public class Match3Board : AbstractBoard
{
public int RandomSeed = -1;
public const int k_EmptyCell = -1;
[Tooltip("Points earned for clearing a basic cell (cube)")]
public int BasicCellPoints = 1;
[Tooltip("Points earned for clearing a special cell (sphere)")]
public int SpecialCell1Points = 2;
[Tooltip("Points earned for clearing an extra special cell (plus)")]
public int SpecialCell2Points = 3;
(int, int)[,] m_Cells;
bool[,] m_Matched;
System.Random m_Random;
void Awake()
{
m_Cells = new (int, int)[Columns, Rows];
m_Matched = new bool[Columns, Rows];
m_Random = new System.Random(RandomSeed == -1 ? gameObject.GetInstanceID() : RandomSeed);
InitRandom();
}
public override bool MakeMove(Move move)
{
if (!IsMoveValid(move))
{
return false;
}
var originalValue = m_Cells[move.Column, move.Row];
var (otherRow, otherCol) = move.OtherCell();
var destinationValue = m_Cells[otherCol, otherRow];
m_Cells[move.Column, move.Row] = destinationValue;
m_Cells[otherCol, otherRow] = originalValue;
return true;
}
public override int GetCellType(int row, int col)
{
return m_Cells[col, row].Item1;
}
public override int GetSpecialType(int row, int col)
{
return m_Cells[col, row].Item2;
}
public override bool IsMoveValid(Move m)
{
if (m_Cells == null)
{
return false;
}
return SimpleIsMoveValid(m);
}
public bool MarkMatchedCells(int[,] cells = null)
{
ClearMarked();
bool madeMatch = false;
for (var i = 0; i < Rows; i++)
{
for (var j = 0; j < Columns; j++)
{
// Check vertically
var matchedRows = 0;
for (var iOffset = i; iOffset < Rows; iOffset++)
{
if (m_Cells[j, i].Item1 != m_Cells[j, iOffset].Item1)
{
break;
}
matchedRows++;
}
if (matchedRows >= 3)
{
madeMatch = true;
for (var k = 0; k < matchedRows; k++)
{
m_Matched[j, i + k] = true;
}
}
// Check horizontally
var matchedCols = 0;
for (var jOffset = j; jOffset < Columns; jOffset++)
{
if (m_Cells[j, i].Item1 != m_Cells[jOffset, i].Item1)
{
break;
}
matchedCols++;
}
if (matchedCols >= 3)
{
madeMatch = true;
for (var k = 0; k < matchedCols; k++)
{
m_Matched[j + k, i] = true;
}
}
}
}
return madeMatch;
}
/// <summary>
/// Sets cells that are matched to the empty cell, and returns the score earned.
/// </summary>
/// <returns></returns>
public int ClearMatchedCells()
{
var pointsByType = new[] { BasicCellPoints, SpecialCell1Points, SpecialCell2Points };
int pointsEarned = 0;
for (var i = 0; i < Rows; i++)
{
for (var j = 0; j < Columns; j++)
{
if (m_Matched[j, i])
{
var specialType = GetSpecialType(i, j);
pointsEarned += pointsByType[specialType];
m_Cells[j, i] = (k_EmptyCell, 0);
}
}
}
ClearMarked(); // TODO clear here or at start of matching?
return pointsEarned;
}
public bool DropCells()
{
var madeChanges = false;
// Gravity is applied in the negative row direction
for (var j = 0; j < Columns; j++)
{
var writeIndex = 0;
for (var readIndex = 0; readIndex < Rows; readIndex++)
{
m_Cells[j, writeIndex] = m_Cells[j, readIndex];
if (m_Cells[j, readIndex].Item1 != k_EmptyCell)
{
writeIndex++;
}
}
// Fill in empties at the end
for (; writeIndex < Rows; writeIndex++)
{
madeChanges = true;
m_Cells[j, writeIndex] = (k_EmptyCell, 0);
}
}
return madeChanges;
}
public bool FillFromAbove()
{
bool madeChanges = false;
for (var i = 0; i < Rows; i++)
{
for (var j = 0; j < Columns; j++)
{
if (m_Cells[j, i].Item1 == k_EmptyCell)
{
madeChanges = true;
m_Cells[j, i] = (GetRandomCellType(), GetRandomSpecialType());
}
}
}
return madeChanges;
}
public (int, int)[,] Cells
{
get { return m_Cells; }
}
public bool[,] Matched
{
get { return m_Matched; }
}
// Initialize the board to random values.
public void InitRandom()
{
for (var i = 0; i < Rows; i++)
{
for (var j = 0; j < Columns; j++)
{
m_Cells[j, i] = (GetRandomCellType(), GetRandomSpecialType());
}
}
}
public void InitSettled()
{
InitRandom();
while (true)
{
var anyMatched = MarkMatchedCells();
if (!anyMatched)
{
return;
}
ClearMatchedCells();
DropCells();
FillFromAbove();
}
}
void ClearMarked()
{
for (var i = 0; i < Rows; i++)
{
for (var j = 0; j < Columns; j++)
{
m_Matched[j, i] = false;
}
}
}
int GetRandomCellType()
{
return m_Random.Next(0, NumCellTypes);
}
int GetRandomSpecialType()
{
// 1 in N chance to get a type-2 special
// 2 in N chance to get a type-1 special
// otherwise 0 (boring)
var N = 10;
var val = m_Random.Next(0, N);
if (val == 0)
{
return 2;
}
if (val <= 2)
{
return 1;
}
return 0;
}
}
}

11
Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs.meta


fileFormatVersion: 2
guid: 6d852a063770348b68caa91b8e7642a5
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

102
Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs


using UnityEngine;
using Unity.MLAgents.Extensions.Match3;
namespace Unity.MLAgentsExamples
{
public class Match3Drawer : MonoBehaviour
{
public int DebugMoveIndex = -1;
static Color[] s_Colors = new[]
{
Color.red,
Color.green,
Color.blue,
Color.cyan,
Color.magenta,
Color.yellow,
Color.gray,
Color.black,
};
private static Color s_EmptyColor = new Color(0.5f, 0.5f, 0.5f, .25f);
void OnDrawGizmos()
{
// TODO replace Gizmos for drawing the game state with proper GameObjects and animations.
var cubeSize = .5f;
var cubeSpacing = .75f;
var matchedWireframeSize = .5f * (cubeSize + cubeSpacing);
var board = GetComponent<Match3Board>();
if (board == null)
{
return;
}
for (var i = 0; i < board.Rows; i++)
{
for (var j = 0; j < board.Columns; j++)
{
var value = board.Cells != null ? board.GetCellType(i, j) : Match3Board.k_EmptyCell;
if (value >= 0 && value < s_Colors.Length)
{
Gizmos.color = s_Colors[value];
}
else
{
Gizmos.color = s_EmptyColor;
}
var pos = new Vector3(j, i, 0);
pos *= cubeSpacing;
var specialType = board.Cells != null ? board.GetSpecialType(i, j) : 0;
if (specialType == 2)
{
Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * new Vector3(1f, .5f, .5f));
Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * new Vector3(.5f, 1f, .5f));
Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * new Vector3(.5f, .5f, 1f));
}
else if (specialType == 1)
{
Gizmos.DrawSphere(transform.TransformPoint(pos), .5f * cubeSize);
}
else
{
Gizmos.DrawCube(transform.TransformPoint(pos), cubeSize * Vector3.one);
}
Gizmos.color = Color.yellow;
if (board.Matched != null && board.Matched[j, i])
{
Gizmos.DrawWireCube(transform.TransformPoint(pos), matchedWireframeSize * Vector3.one);
}
}
}
// Draw valid moves
foreach (var move in board.AllMoves())
{
if (DebugMoveIndex >= 0 && move.MoveIndex != DebugMoveIndex)
{
continue;
}
if (!board.IsMoveValid(move))
{
continue;
}
var (otherRow, otherCol) = move.OtherCell();
var pos = new Vector3(move.Column, move.Row, 0) * cubeSpacing;
var otherPos = new Vector3(otherCol, otherRow, 0) * cubeSpacing;
var oneQuarter = Vector3.Lerp(pos, otherPos, .25f);
var threeQuarters = Vector3.Lerp(pos, otherPos, .75f);
Gizmos.DrawLine(transform.TransformPoint(oneQuarter), transform.TransformPoint(threeQuarters));
}
}
}
}

3
Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Drawer.cs.meta


fileFormatVersion: 2
guid: abebb7ad4a5547d7a3b04373784ff195
timeCreated: 1598221188

8
Project/Assets/ML-Agents/Examples/Match3/TFModels.meta


fileFormatVersion: 2
guid: 504c8f923fdf448e795936f2900a5fd4
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

1001
Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx
File diff suppressed because it is too large

14
Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VectorObs.onnx.meta


fileFormatVersion: 2
guid: 9e89b8e81974148d3b7213530d00589d
ScriptedImporter:
fileIDToRecycleName:
11400000: main obj
11400002: model data
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:
script: {fileID: 11500000, guid: 683b6cb6d0a474744822c888b46772c9, type: 3}
optimizeModel: 1
forceArbitraryBatchSize: 1
treatErrorsAsWarnings: 0

1001
Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn
File diff suppressed because it is too large

11
Project/Assets/ML-Agents/Examples/Match3/TFModels/Match3VisualObs.nn.meta


fileFormatVersion: 2
guid: 48d14da88fea74d0693c691c6e3f2e34
ScriptedImporter:
fileIDToRecycleName:
11400000: main obj
11400002: model data
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:
script: {fileID: 11500000, guid: 19ed1486aa27d4903b34839f37b8f69f, type: 3}

233
com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs


using System;
using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Match3
{
public abstract class AbstractBoard : MonoBehaviour
{
/// <summary>
/// Number of rows on the board
/// </summary>
public int Rows;
/// <summary>
/// Number of columns on the board
/// </summary>
public int Columns;
/// <summary>
/// Maximum number of different types of cells (colors, pieces, etc).
/// </summary>
public int NumCellTypes;
/// <summary>
/// Maximum number of special types. This can be zero, in which case
/// all cells of the same type are assumed to be equivalent.
/// </summary>
public int NumSpecialTypes;
/// <summary>
/// Returns the "color" of the piece at the given row and column.
/// This should be between 0 and NumCellTypes-1 (inclusive).
/// The actual order of the values doesn't matter.
/// </summary>
/// <param name="row"></param>
/// <param name="col"></param>
/// <returns></returns>
public abstract int GetCellType(int row, int col);
/// <summary>
/// Returns the special type of the piece at the given row and column.
/// This should be between 0 and NumSpecialTypes (inclusive).
/// The actual order of the values doesn't matter.
/// </summary>
/// <param name="row"></param>
/// <param name="col"></param>
/// <returns></returns>
public abstract int GetSpecialType(int row, int col);
/// <summary>
/// Check whether the particular Move is valid for the game.
/// The actual results will depend on the rules of the game, but we provide SimpleIsMoveValid()
/// that handles basic match3 rules with no special or immovable pieces.
/// </summary>
/// <param name="m"></param>
/// <returns></returns>
public abstract bool IsMoveValid(Move m);
/// <summary>
/// Instruct the game to make the given move. Returns true if the move was made.
/// Note that during training, a move that was marked as invalid may occasionally still be
/// requested. If this happens, it is safe to do nothing and request another move.
/// </summary>
/// <param name="m"></param>
/// <returns></returns>
public abstract bool MakeMove(Move m);
/// <summary>
/// Return the total number of moves possible for the board.
/// </summary>
/// <returns></returns>
public int NumMoves()
{
return Move.NumPotentialMoves(Rows, Columns);
}
/// <summary>
/// An optional callback for when all moves are invalid. Ideally, the game state should
/// be changed before this happens, but this is a way to get notified if not.
/// </summary>
public Action OnNoValidMovesAction;
/// <summary>
/// Iterate through all Moves on the board.
/// </summary>
/// <returns></returns>
public IEnumerable<Move> AllMoves()
{
var currentMove = Move.FromMoveIndex(0, Rows, Columns);
for (var i = 0; i < NumMoves(); i++)
{
yield return currentMove;
currentMove.Next(Rows, Columns);
}
}
/// <summary>
/// Iterate through all valid Moves on the board.
/// </summary>
/// <returns></returns>
public IEnumerable<Move> ValidMoves()
{
var currentMove = Move.FromMoveIndex(0, Rows, Columns);
for (var i = 0; i < NumMoves(); i++)
{
if (IsMoveValid(currentMove))
{
yield return currentMove;
}
currentMove.Next(Rows, Columns);
}
}
/// <summary>
/// Iterate through all invalid Moves on the board.
/// </summary>
/// <returns></returns>
public IEnumerable<Move> InvalidMoves()
{
var currentMove = Move.FromMoveIndex(0, Rows, Columns);
for (var i = 0; i < NumMoves(); i++)
{
if (!IsMoveValid(currentMove))
{
yield return currentMove;
}
currentMove.Next(Rows, Columns);
}
}
/// <summary>
/// Returns true if swapping the cells specified by the move would result in
/// 3 or more cells of the same type in a row. This assumes that all pieces are allowed
/// to be moved; to add extra logic, incorporate it into your IsMoveValid() method.
/// </summary>
/// <param name="move"></param>
/// <returns></returns>
public bool SimpleIsMoveValid(Move move)
{
using (TimerStack.Instance.Scoped("SimpleIsMoveValid"))
{
var moveVal = GetCellType(move.Row, move.Column);
var (otherRow, otherCol) = move.OtherCell();
var oppositeVal = GetCellType(otherRow, otherCol);
// Simple check - if the values are the same, don't match
// This might not be valid for all games
if (moveVal == oppositeVal)
{
return false;
}
bool moveMatches = CheckHalfMove(otherRow, otherCol, moveVal, move.Direction);
if (moveMatches)
{
// early out
return true;
}
bool otherMatches = CheckHalfMove(move.Row, move.Column, oppositeVal, move.OtherDirection());
return otherMatches;
}
}
/// <summary>
/// Check if one of the cells that is swapped during a move matches 3 or more.
/// Since these checks are similar for each cell, we consider the Move as two "half moves".
/// </summary>
/// <param name="newRow"></param>
/// <param name="newCol"></param>
/// <param name="newValue"></param>
/// <param name="incomingDirection"></param>
/// <returns></returns>
bool CheckHalfMove(int newRow, int newCol, int newValue, Direction incomingDirection)
{
int matchedLeft = 0, matchedRight = 0, matchedUp = 0, matchedDown = 0;
if (incomingDirection != Direction.Right)
{
for (var c = newCol - 1; c >= 0; c--)
{
if (GetCellType(newRow, c) == newValue)
matchedLeft++;
else
break;
}
}
if (incomingDirection != Direction.Left)
{
for (var c = newCol + 1; c < Columns; c++)
{
if (GetCellType(newRow, c) == newValue)
matchedRight++;
else
break;
}
}
if (incomingDirection != Direction.Down)
{
for (var r = newRow + 1; r < Rows; r++)
{
if (GetCellType(r, newCol) == newValue)
matchedUp++;
else
break;
}
}
if (incomingDirection != Direction.Up)
{
for (var r = newRow - 1; r >= 0; r--)
{
if (GetCellType(r, newCol) == newValue)
matchedDown++;
else
break;
}
}
if ((matchedUp + matchedDown >= 2) || (matchedLeft + matchedRight >= 2))
{
return true;
}
return false;
}
}
}

3
com.unity.ml-agents.extensions/Runtime/Match3/AbstractBoard.cs.meta


fileFormatVersion: 2
guid: 6222defa70dc4c08aaeafd0be4e821d2
timeCreated: 1600466051

120
com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs


using System.Collections.Generic;
using Unity.MLAgents.Actuators;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Match3
{
/// <summary>
/// Actuator for a Match3 game. It translates valid moves (defined by AbstractBoard.IsMoveValid())
/// into action masks, and applies the action to the board via AbstractBoard.MakeMove().
/// </summary>
public class Match3Actuator : IActuator
{
private AbstractBoard m_Board;
private ActionSpec m_ActionSpec;
private bool m_ForceHeuristic;
private System.Random m_Random;
private Agent m_Agent;
private int m_Rows;
private int m_Columns;
private int m_NumCellTypes;
/// <summary>
/// Create a Match3Actuator.
/// </summary>
/// <param name="board"></param>
/// <param name="forceHeuristic">Whether the inference action should be ignored and the Agent's Heuristic
/// should be called. This should only be used for generating comparison stats of the Heuristic.</param>
/// <param name="agent"></param>
/// <param name="name"></param>
public Match3Actuator(AbstractBoard board, bool forceHeuristic, Agent agent, string name)
{
m_Board = board;
m_Rows = board.Rows;
m_Columns = board.Columns;
m_NumCellTypes = board.NumCellTypes;
Name = name;
m_ForceHeuristic = forceHeuristic;
m_Agent = agent;
var numMoves = Move.NumPotentialMoves(m_Board.Rows, m_Board.Columns);
m_ActionSpec = ActionSpec.MakeDiscrete(numMoves);
}
/// <inheritdoc/>
public ActionSpec ActionSpec => m_ActionSpec;
/// <inheritdoc/>
public void OnActionReceived(ActionBuffers actions)
{
if (m_ForceHeuristic)
{
m_Agent.Heuristic(actions);
}
var moveIndex = actions.DiscreteActions[0];
if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes)
{
Debug.LogWarning(
$"Board shape changes since actuator initialization. This may cause unexpected results. " +
$"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " +
$"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}"
);
}
Move move = Move.FromMoveIndex(moveIndex, m_Rows, m_Columns);
m_Board.MakeMove(move);
}
/// <inheritdoc/>
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
using (TimerStack.Instance.Scoped("WriteDiscreteActionMask"))
{
actionMask.WriteMask(0, InvalidMoveIndices());
}
}
/// <inheritdoc/>
public string Name { get; }
/// <inheritdoc/>
public void ResetData()
{
}
IEnumerable<int> InvalidMoveIndices()
{
var numValidMoves = m_Board.NumMoves();
foreach (var move in m_Board.InvalidMoves())
{
numValidMoves--;
if (numValidMoves == 0)
{
// If all the moves are invalid and we mask all the actions out, this will cause an assert
// later on in IDiscreteActionMask. Instead, fire a callback to the user if they provided one
// (or log a warning if not) and leave the last action unmasked. This isn't great, but
// an invalid move should be easier to handle than an exception.
if (m_Board.OnNoValidMovesAction != null)
{
m_Board.OnNoValidMovesAction();
}
else
{
Debug.LogWarning(
"No valid moves are available. The last action will be left unmasked, so " +
"an invalid move will be passed to AbstractBoard.MakeMove()."
);
}
// This means the last move won't be returned as an invalid index.
yield break;
}
yield return move.MoveIndex;
}
}
}
}

3
com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs.meta


fileFormatVersion: 2
guid: 9083fa4c35dc499aa5a86d8e7447c7cf
timeCreated: 1600906373

49
com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs


using Unity.MLAgents.Actuators;
using UnityEngine;
using UnityEngine.Serialization;
namespace Unity.MLAgents.Extensions.Match3
{
/// <summary>
/// Actuator component for a Match 3 game. Generates a Match3Actuator at runtime.
/// </summary>
public class Match3ActuatorComponent : ActuatorComponent
{
/// <summary>
/// Name of the generated Match3Actuator object.
/// Note that changing this at runtime does not affect how the Agent sorts the actuators.
/// </summary>
public string ActuatorName = "Match3 Actuator";
/// <summary>
/// Force using the Agent's Heuristic() method to decide the action. This should only be used in testing.
/// </summary>
[FormerlySerializedAs("ForceRandom")]
[Tooltip("Force using the Agent's Heuristic() method to decide the action. This should only be used in testing.")]
public bool ForceHeuristic = false;
/// <inheritdoc/>
public override IActuator CreateActuator()
{
var board = GetComponent<AbstractBoard>();
var agent = GetComponentInParent<Agent>();
return new Match3Actuator(board, ForceHeuristic, agent, ActuatorName);
}
/// <inheritdoc/>
public override ActionSpec ActionSpec
{
get
{
var board = GetComponent<AbstractBoard>();
if (board == null)
{
return ActionSpec.MakeContinuous(0);
}
var numMoves = Move.NumPotentialMoves(board.Rows, board.Columns);
return ActionSpec.MakeDiscrete(numMoves);
}
}
}
}

3
com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs.meta


fileFormatVersion: 2
guid: 08e4b0da54cb4d56bfcbae22dd49ab8d
timeCreated: 1600906388

297
com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs


using System.Collections.Generic;
using Unity.MLAgents.Sensors;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Match3
{
/// <summary>
/// Type of observations to generate.
///
/// </summary>
public enum Match3ObservationType
{
/// <summary>
/// Generate a one-hot encoding of the cell type for each cell on the board. If there are special types,
/// these will also be one-hot encoded.
/// </summary>
Vector,
/// <summary>
/// Generate a one-hot encoding of the cell type for each cell on the board, but arranged as
/// a Rows x Columns visual observation. If there are special types, these will also be one-hot encoded.
/// </summary>
UncompressedVisual,
/// <summary>
/// Generate a one-hot encoding of the cell type for each cell on the board, but arranged as
/// a Rows x Columns visual observation. If there are special types, these will also be one-hot encoded.
/// During training, these will be sent as a concatenated series of PNG images, with 3 channels per image.
/// </summary>
CompressedVisual
}
/// <summary>
/// Sensor for Match3 games. Can generate either vector, compressed visual,
/// or uncompressed visual observations. Uses AbstractBoard.GetCellType()
/// and AbstractBoard.GetSpecialType() to determine the observation values.
/// </summary>
public class Match3Sensor : ISparseChannelSensor
{
private Match3ObservationType m_ObservationType;
private AbstractBoard m_Board;
private int[] m_Shape;
private int[] m_SparseChannelMapping;
private string m_Name;
private int m_Rows;
private int m_Columns;
private int m_NumCellTypes;
private int m_NumSpecialTypes;
private ISparseChannelSensor sparseChannelSensorImplementation;
private int SpecialTypeSize
{
get { return m_NumSpecialTypes == 0 ? 0 : m_NumSpecialTypes + 1; }
}
/// <summary>
/// Create a sensor for the board with the specified observation type.
/// </summary>
/// <param name="board"></param>
/// <param name="obsType"></param>
/// <param name="name"></param>
public Match3Sensor(AbstractBoard board, Match3ObservationType obsType, string name)
{
m_Board = board;
m_Name = name;
m_Rows = board.Rows;
m_Columns = board.Columns;
m_NumCellTypes = board.NumCellTypes;
m_NumSpecialTypes = board.NumSpecialTypes;
m_ObservationType = obsType;
m_Shape = obsType == Match3ObservationType.Vector ?
new[] { m_Rows * m_Columns * (m_NumCellTypes + SpecialTypeSize) } :
new[] { m_Rows, m_Columns, m_NumCellTypes + SpecialTypeSize };
// See comment in GetCompressedObservation()
var cellTypePaddedSize = 3 * ((m_NumCellTypes + 2) / 3);
m_SparseChannelMapping = new int[cellTypePaddedSize + SpecialTypeSize];
// If we have 4 cell types and 2 special types (3 special size), we'd have
// [0, 1, 2, 3, -1, -1, 4, 5, 6]
for (var i = 0; i < m_NumCellTypes; i++)
{
m_SparseChannelMapping[i] = i;
}
for (var i = m_NumCellTypes; i < cellTypePaddedSize; i++)
{
m_SparseChannelMapping[i] = -1;
}
for (var i = 0; i < SpecialTypeSize; i++)
{
m_SparseChannelMapping[cellTypePaddedSize + i] = i + m_NumCellTypes;
}
}
/// <inheritdoc/>
public int[] GetObservationShape()
{
return m_Shape;
}
/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
if (m_Board.Rows != m_Rows || m_Board.Columns != m_Columns || m_Board.NumCellTypes != m_NumCellTypes)
{
Debug.LogWarning(
$"Board shape changes since sensor initialization. This may cause unexpected results. " +
$"Old shape: Rows={m_Rows} Columns={m_Columns}, NumCellTypes={m_NumCellTypes} " +
$"Current shape: Rows={m_Board.Rows} Columns={m_Board.Columns}, NumCellTypes={m_Board.NumCellTypes}"
);
}
if (m_ObservationType == Match3ObservationType.Vector)
{
int offset = 0;
for (var r = 0; r < m_Rows; r++)
{
for (var c = 0; c < m_Columns; c++)
{
var val = m_Board.GetCellType(r, c);
for (var i = 0; i < m_NumCellTypes; i++)
{
writer[offset] = (i == val) ? 1.0f : 0.0f;
offset++;
}
if (m_NumSpecialTypes > 0)
{
var special = m_Board.GetSpecialType(r, c);
for (var i = 0; i < SpecialTypeSize; i++)
{
writer[offset] = (i == special) ? 1.0f : 0.0f;
offset++;
}
}
}
}
return offset;
}
else
{
// TODO combine loops? Only difference is inner-most statement.
int offset = 0;
for (var r = 0; r < m_Rows; r++)
{
for (var c = 0; c < m_Columns; c++)
{
var val = m_Board.GetCellType(r, c);
for (var i = 0; i < m_NumCellTypes; i++)
{
writer[r, c, i] = (i == val) ? 1.0f : 0.0f;
offset++;
}
if (m_NumSpecialTypes > 0)
{
var special = m_Board.GetSpecialType(r, c);
for (var i = 0; i < SpecialTypeSize; i++)
{
writer[r, c, m_NumCellTypes + i] = (i == special) ? 1.0f : 0.0f;
offset++;
}
}
}
}
return offset;
}
}
/// <inheritdoc/>
public byte[] GetCompressedObservation()
{
var height = m_Rows;
var width = m_Columns;
var tempTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
var converter = new OneHotToTextureUtil(height, width);
var bytesOut = new List<byte>();
// Encode the cell types and special types as separate batches of PNGs.
// This is potentially wasteful, e.g. if there are 4 cell types and 1 special type, we could
// fit it in 2 images, but we'll use 3 here (2 PNGs for the 4 cell type channels, and 1 for
// the special types). Note that we also have to implement the sparse channel mapping.
// Optimize this later.
var numCellImages = (m_NumCellTypes + 2) / 3;
for (var i = 0; i < numCellImages; i++)
{
converter.EncodeToTexture(m_Board.GetCellType, tempTexture, 3 * i);
bytesOut.AddRange(tempTexture.EncodeToPNG());
}
var numSpecialImages = (SpecialTypeSize + 2) / 3;
for (var i = 0; i < numSpecialImages; i++)
{
converter.EncodeToTexture(m_Board.GetSpecialType, tempTexture, 3 * i);
bytesOut.AddRange(tempTexture.EncodeToPNG());
}
DestroyTexture(tempTexture);
return bytesOut.ToArray();
}
/// <inheritdoc/>
public void Update()
{
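// Nothing to do here; the board state is read fresh on each Write() call.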
}
/// <inheritdoc/>
public void Reset()
{
}
/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
{
return m_ObservationType == Match3ObservationType.CompressedVisual ?
SensorCompressionType.PNG :
SensorCompressionType.None;
}
/// <inheritdoc/>
public string GetName()
{
return m_Name;
}
/// <inheritdoc/>
public int[] GetCompressedChannelMapping()
{
return m_SparseChannelMapping;
}
static void DestroyTexture(Texture2D texture)
{
if (Application.isEditor)
{
// Edit Mode tests complain if we use Destroy()
Object.DestroyImmediate(texture);
}
else
{
Object.Destroy(texture);
}
}
}
/// <summary>
/// Utility class for converting a 2D array of ints representing a one-hot encoding into
/// a texture, suitable for conversion to PNGs for observations.
/// Works by encoding 3 values at a time as pixels in the texture, thus it should be
/// called (maxValue + 2) / 3 times, increasing the channelOffset by 3 each time.
/// </summary>
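/// <remarks>
/// For example, with 5 possible values (0 through 4), EncodeToTexture should be called
/// twice: once with channelOffset 0 (values 0, 1, 2 map to red, green, and blue) and once
/// with channelOffset 3 (values 3 and 4 map to red and green; all other cells are black).
/// </remarks>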
internal class OneHotToTextureUtil
{
Color[] m_Colors;
int m_Height;
int m_Width;
private static Color[] s_OneHotColors = { Color.red, Color.green, Color.blue };
public delegate int GridValueProvider(int x, int y);
public OneHotToTextureUtil(int height, int width)
{
m_Colors = new Color[height * width];
m_Height = height;
m_Width = width;
}
public void EncodeToTexture(GridValueProvider gridValueProvider, Texture2D texture, int channelOffset)
{
var i = 0;
// There's an implicit flip converting to PNG from texture, so make sure we
// counteract that when forming the texture by iterating through h in reverse.
for (var h = m_Height - 1; h >= 0; h--)
{
for (var w = 0; w < m_Width; w++)
{
int oneHotValue = gridValueProvider(h, w);
if (oneHotValue < channelOffset || oneHotValue >= channelOffset + 3)
{
m_Colors[i++] = Color.black;
}
else
{
m_Colors[i++] = s_OneHotColors[oneHotValue - channelOffset];
}
}
}
texture.SetPixels(m_Colors);
}
}
}

3
com.unity.ml-agents.extensions/Runtime/Match3/Match3Sensor.cs.meta


fileFormatVersion: 2
guid: 795ad5f211e344e5bf3049abd9499721
timeCreated: 1600906663

43
com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs


using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Match3
{
/// <summary>
/// Sensor component for a Match3 game.
/// </summary>
public class Match3SensorComponent : SensorComponent
{
/// <summary>
/// Name of the generated Match3Sensor object.
/// Note that changing this at runtime does not affect how the Agent sorts the sensors.
/// </summary>
public string SensorName = "Match3 Sensor";
/// <summary>
/// Type of observation to generate.
/// </summary>
public Match3ObservationType ObservationType = Match3ObservationType.Vector;
/// <inheritdoc/>
public override ISensor CreateSensor()
{
var board = GetComponent<AbstractBoard>();
return new Match3Sensor(board, ObservationType, SensorName);
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
var board = GetComponent<AbstractBoard>();
if (board == null)
{
return System.Array.Empty<int>();
}
var specialSize = board.NumSpecialTypes == 0 ? 0 : board.NumSpecialTypes + 1;
return ObservationType == Match3ObservationType.Vector ?
new[] { board.Rows * board.Columns * (board.NumCellTypes + specialSize) } :
new[] { board.Rows, board.Columns, board.NumCellTypes + specialSize };
}
}
}
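
A minimal usage sketch of wiring these pieces together, assuming a hypothetical MyBoard subclass of AbstractBoard supplied by the game (not part of this package). The sensor component looks up the AbstractBoard on its own GameObject, so the board, the sensor component, and the Agent are typically attached together:

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Extensions.Match3;

public class Match3AgentSetup : MonoBehaviour
{
    void Awake()
    {
        // MyBoard is a hypothetical AbstractBoard implementation supplied by the game.
        gameObject.AddComponent<MyBoard>();
        var sensorComponent = gameObject.AddComponent<Match3SensorComponent>();
        sensorComponent.ObservationType = Match3ObservationType.Vector;
        // The Agent discovers SensorComponents on the same GameObject during initialization.
        gameObject.AddComponent<Agent>();
    }
}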

3
com.unity.ml-agents.extensions/Runtime/Match3/Match3SensorComponent.cs.meta


fileFormatVersion: 2
guid: 530d2f105aa145bd8a00e021bdd925fd
timeCreated: 1600906676

260
com.unity.ml-agents.extensions/Runtime/Match3/Move.cs


using System;
namespace Unity.MLAgents.Extensions.Match3
{
/// <summary>
/// Directions for a Move.
/// </summary>
public enum Direction
{
/// <summary>
/// Move up (increasing row direction).
/// </summary>
Up,
/// <summary>
/// Move down (decreasing row direction).
/// </summary>
Down, // -row direction
/// <summary>
/// Move left (decreasing column direction).
/// </summary>
Left, // -column direction
/// <summary>
/// Move right (increasing column direction).
/// </summary>
Right, // +column direction
}
/// <summary>
/// Struct that encapsulates a swap of adjacent cells.
/// A Move can be constructed from either a starting row, column, and direction,
/// or from a "move index" between 0 and NumPotentialMoves()-1.
/// Moves are enumerated as the internal edges of the game grid.
/// Left/right moves come first. There are (maxCols - 1) * maxRows of these.
/// Up/down moves are next. There are (maxRows - 1) * maxCols of these.
/// </summary>
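/// <remarks>
/// For example, on a 3x3 board there are (3 - 1) * 3 = 6 left/right moves
/// (indices 0 through 5), then (3 - 1) * 3 = 6 up/down moves (indices 6 through 11),
/// so NumPotentialMoves(3, 3) == 12.
/// </remarks>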
public struct Move
{
/// <summary>
/// Index of the move, from 0 to NumPotentialMoves-1.
/// </summary>
public int MoveIndex;
/// <summary>
/// Row of the cell that will be moved.
/// </summary>
public int Row;
/// <summary>
/// Column of the cell that will be moved.
/// </summary>
public int Column;
/// <summary>
/// Direction that the cell will be moved.
/// </summary>
public Direction Direction;
/// <summary>
/// Construct a Move from its move index and the board size.
/// This is useful for iterating through all the Moves on a board, or constructing
/// the Move corresponding to an Agent decision.
/// </summary>
/// <param name="moveIndex">Must be between 0 and NumPotentialMoves(maxRows, maxCols).</param>
/// <param name="maxRows"></param>
/// <param name="maxCols"></param>
/// <returns></returns>
/// <exception cref="ArgumentOutOfRangeException"></exception>
public static Move FromMoveIndex(int moveIndex, int maxRows, int maxCols)
{
if (moveIndex < 0 || moveIndex >= NumPotentialMoves(maxRows, maxCols))
{
throw new ArgumentOutOfRangeException(nameof(moveIndex), "Invalid move index.");
}
Direction dir;
int row, col;
if (moveIndex < (maxCols - 1) * maxRows)
{
dir = Direction.Right;
col = moveIndex % (maxCols - 1);
row = moveIndex / (maxCols - 1);
}
else
{
dir = Direction.Up;
var offset = moveIndex - (maxCols - 1) * maxRows;
col = offset % maxCols;
row = offset / maxCols;
}
return new Move
{
MoveIndex = moveIndex,
Direction = dir,
Row = row,
Column = col
};
}
/// <summary>
/// Increment the Move to the next MoveIndex, and update the Row, Column, and Direction accordingly.
/// </summary>
/// <param name="maxRows"></param>
/// <param name="maxCols"></param>
public void Next(int maxRows, int maxCols)
{
var switchoverIndex = (maxCols - 1) * maxRows;
MoveIndex++;
if (MoveIndex < switchoverIndex)
{
Column++;
if (Column == maxCols - 1)
{
Row++;
Column = 0;
}
}
else if (MoveIndex == switchoverIndex)
{
// switch from moving right to moving up
Row = 0;
Column = 0;
Direction = Direction.Up;
}
else
{
Column++;
if (Column == maxCols)
{
Row++;
Column = 0;
}
}
}
/// <summary>
/// Construct a Move from the row, column, and direction.
/// </summary>
/// <param name="row"></param>
/// <param name="col"></param>
/// <param name="dir"></param>
/// <param name="maxRows"></param>
/// <param name="maxCols"></param>
/// <returns></returns>
public static Move FromPositionAndDirection(int row, int col, Direction dir, int maxRows, int maxCols)
{
// Check for out-of-bounds
if (row < 0 || row >= maxRows)
{
throw new IndexOutOfRangeException($"row was {row}, but must be between 0 and {maxRows - 1}.");
}
if (col < 0 || col >= maxCols)
{
throw new IndexOutOfRangeException($"col was {col}, but must be between 0 and {maxCols - 1}.");
}
// Check moves that would go out of bounds e.g. col == 0 and dir == Left
if (
row == 0 && dir == Direction.Down ||
row == maxRows - 1 && dir == Direction.Up ||
col == 0 && dir == Direction.Left ||
col == maxCols - 1 && dir == Direction.Right
)
{
throw new IndexOutOfRangeException($"Cannot move cell at row={row} col={col} in Direction={dir}");
}
// Normalize - only consider Right and Up
if (dir == Direction.Left)
{
dir = Direction.Right;
col = col - 1;
}
else if (dir == Direction.Down)
{
dir = Direction.Up;
row = row - 1;
}
int moveIndex;
if (dir == Direction.Right)
{
moveIndex = col + row * (maxCols - 1);
}
else
{
var offset = (maxCols - 1) * maxRows;
moveIndex = offset + col + row * maxCols;
}
return new Move
{
Row = row,
Column = col,
Direction = dir,
MoveIndex = moveIndex,
};
}
/// <summary>
/// Get the other row and column that correspond to this move.
/// </summary>
/// <returns>The row and column of the cell on the other end of this move.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if the Direction is invalid.</exception>
public (int Row, int Column) OtherCell()
{
switch (Direction)
{
case Direction.Up:
return (Row + 1, Column);
case Direction.Down:
return (Row - 1, Column);
case Direction.Left:
return (Row, Column - 1);
case Direction.Right:
return (Row, Column + 1);
default:
throw new ArgumentOutOfRangeException();
}
}
/// <summary>
/// Get the opposite direction of this move.
/// </summary>
/// <returns>The opposite of this move's Direction.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if the Direction is invalid.</exception>
public Direction OtherDirection()
{
switch (Direction)
{
case Direction.Up:
return Direction.Down;
case Direction.Down:
return Direction.Up;
case Direction.Left:
return Direction.Right;
case Direction.Right:
return Direction.Left;
default:
throw new ArgumentOutOfRangeException();
}
}
/// <summary>
/// Return the number of potential moves for a board of the given size.
/// This is equivalent to the number of internal edges in the board.
/// </summary>
/// <param name="maxRows"></param>
/// <param name="maxCols"></param>
/// <returns></returns>
public static int NumPotentialMoves(int maxRows, int maxCols)
{
return maxRows * (maxCols - 1) + (maxRows - 1) * maxCols;
}
}
}
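
A short sketch of walking every potential move via the index-based enumeration described above (the MonoBehaviour wrapper here is just for illustration; only Move itself comes from the package):

using UnityEngine;
using Unity.MLAgents.Extensions.Match3;

public class MoveEnumerationExample : MonoBehaviour
{
    void Start()
    {
        const int maxRows = 3;
        const int maxCols = 3;
        // NumPotentialMoves(3, 3) == 3 * 2 + 2 * 3 == 12
        for (var i = 0; i < Move.NumPotentialMoves(maxRows, maxCols); i++)
        {
            var move = Move.FromMoveIndex(i, maxRows, maxCols);
            Debug.Log($"{move.MoveIndex}: ({move.Row}, {move.Column}) moving {move.Direction}");
        }
    }
}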

3
com.unity.ml-agents.extensions/Runtime/Match3/Move.cs.meta


fileFormatVersion: 2
guid: 41d6d7b9e07c4ef1ae075c74a906906b
timeCreated: 1600466100

152
com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs


using System;
using System.Collections.Generic;
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Match3;
namespace Unity.MLAgents.Extensions.Tests.Match3
{
internal class StringBoard : AbstractBoard
{
private string[] m_Board;
private string[] m_Special;
/// <summary>
/// Convert a string like "000\n010\n000" to a board representation.
/// Row 0 is considered the bottom row.
/// </summary>
/// <param name="newBoard">String representation of the board.</param>
public void SetBoard(string newBoard)
{
m_Board = newBoard.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);
Rows = m_Board.Length;
Columns = m_Board[0].Length;
NumCellTypes = 0;
for (var r = 0; r < Rows; r++)
{
for (var c = 0; c < Columns; c++)
{
NumCellTypes = Mathf.Max(NumCellTypes, 1 + GetCellType(r, c));
}
}
}
public void SetSpecial(string newSpecial)
{
m_Special = newSpecial.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);
Debug.Assert(Rows == m_Special.Length);
Debug.Assert(Columns == m_Special[0].Length);
NumSpecialTypes = 0;
for (var r = 0; r < Rows; r++)
{
for (var c = 0; c < Columns; c++)
{
NumSpecialTypes = Mathf.Max(NumSpecialTypes, GetSpecialType(r, c));
}
}
}
public override bool MakeMove(Move m)
{
return true;
}
public override bool IsMoveValid(Move m)
{
return SimpleIsMoveValid(m);
}
public override int GetCellType(int row, int col)
{
var character = m_Board[m_Board.Length - 1 - row][col];
return (int)(character - '0');
}
public override int GetSpecialType(int row, int col)
{
var character = m_Special[m_Board.Length - 1 - row][col];
return (int)(character - '0');
}
}
public class AbstractBoardTests
{
[Test]
public void TestBoardInit()
{
var boardString =
@"000
000
010";
var gameObj = new GameObject("board");
var board = gameObj.AddComponent<StringBoard>();
board.SetBoard(boardString);
Assert.AreEqual(3, board.Rows);
Assert.AreEqual(3, board.Columns);
Assert.AreEqual(2, board.NumCellTypes);
for (var r = 0; r < 3; r++)
{
for (var c = 0; c < 3; c++)
{
var expected = (r == 0 && c == 1) ? 1 : 0;
Assert.AreEqual(expected, board.GetCellType(r, c));
}
}
}
[Test]
public void TestCheckValidMoves()
{
var gameObj = new GameObject("board");
var board = gameObj.AddComponent<StringBoard>();
var boardString =
@"0105
1024
0203
2022";
board.SetBoard(boardString);
var validMoves = new[]
{
Move.FromPositionAndDirection(2, 1, Direction.Up, board.Rows, board.Columns), // equivalent to (3, 1, Down)
Move.FromPositionAndDirection(2, 1, Direction.Left, board.Rows, board.Columns), // equivalent to (2, 0, Right)
Move.FromPositionAndDirection(2, 1, Direction.Down, board.Rows, board.Columns), // equivalent to (1, 1, Up)
Move.FromPositionAndDirection(2, 1, Direction.Right, board.Rows, board.Columns),
Move.FromPositionAndDirection(1, 1, Direction.Down, board.Rows, board.Columns),
Move.FromPositionAndDirection(1, 1, Direction.Left, board.Rows, board.Columns),
Move.FromPositionAndDirection(1, 1, Direction.Right, board.Rows, board.Columns),
Move.FromPositionAndDirection(0, 1, Direction.Left, board.Rows, board.Columns),
};
foreach (var m in validMoves)
{
Assert.IsTrue(board.IsMoveValid(m));
}
// Run through all moves and make sure those are the only valid ones
HashSet<int> validIndices = new HashSet<int>();
foreach (var m in validMoves)
{
validIndices.Add(m.MoveIndex);
}
foreach (var move in board.AllMoves())
{
var expected = validIndices.Contains(move.MoveIndex);
Assert.AreEqual(expected, board.IsMoveValid(move), $"({move.Row}, {move.Column}, {move.Direction})");
}
HashSet<int> validIndicesFromIterator = new HashSet<int>();
foreach (var move in board.ValidMoves())
{
validIndicesFromIterator.Add(move.MoveIndex);
}
Assert.IsTrue(validIndices.SetEquals(validIndicesFromIterator));
}
}
}

3
com.unity.ml-agents.extensions/Tests/Editor/Match3/AbstractBoardTests.cs.meta


fileFormatVersion: 2
guid: a6d0404471364cd5b0b86ef72e6fe653
timeCreated: 1601332740

115
com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs


using NUnit.Framework;
using Unity.MLAgents.Extensions.Match3;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Tests.Match3
{
internal class SimpleBoard : AbstractBoard
{
public int LastMoveIndex;
public bool MovesAreValid = true;
public bool CallbackCalled;
public override int GetCellType(int row, int col)
{
return 0;
}
public override int GetSpecialType(int row, int col)
{
return 0;
}
public override bool IsMoveValid(Move m)
{
return MovesAreValid;
}
public override bool MakeMove(Move m)
{
LastMoveIndex = m.MoveIndex;
return MovesAreValid;
}
public void Callback()
{
CallbackCalled = true;
}
}
public class Match3ActuatorTests
{
[SetUp]
public void SetUp()
{
if (Academy.IsInitialized)
{
Academy.Instance.Dispose();
}
}
[TestCase(true)]
[TestCase(false)]
public void TestValidMoves(bool movesAreValid)
{
// Check that a board with no valid moves doesn't raise an exception.
var gameObj = new GameObject();
var board = gameObj.AddComponent<SimpleBoard>();
var agent = gameObj.AddComponent<Agent>();
gameObj.AddComponent<Match3ActuatorComponent>();
board.Rows = 5;
board.Columns = 5;
board.NumCellTypes = 5;
board.NumSpecialTypes = 0;
board.MovesAreValid = movesAreValid;
board.OnNoValidMovesAction = board.Callback;
board.LastMoveIndex = -1;
agent.LazyInitialize();
agent.RequestDecision();
Academy.Instance.EnvironmentStep();
if (movesAreValid)
{
Assert.IsFalse(board.CallbackCalled);
}
else
{
Assert.IsTrue(board.CallbackCalled);
}
Assert.AreNotEqual(-1, board.LastMoveIndex);
}
[Test]
public void TestActionSpec()
{
var gameObj = new GameObject();
var board = gameObj.AddComponent<SimpleBoard>();
var actuator = gameObj.AddComponent<Match3ActuatorComponent>();
board.Rows = 5;
board.Columns = 5;
board.NumCellTypes = 5;
board.NumSpecialTypes = 0;
var actionSpec = actuator.ActionSpec;
Assert.AreEqual(1, actionSpec.NumDiscreteActions);
Assert.AreEqual(board.NumMoves(), actionSpec.BranchSizes[0]);
}
[Test]
public void TestActionSpecNullBoard()
{
var gameObj = new GameObject();
var actuator = gameObj.AddComponent<Match3ActuatorComponent>();
var actionSpec = actuator.ActionSpec;
Assert.AreEqual(0, actionSpec.NumDiscreteActions);
Assert.AreEqual(0, actionSpec.NumContinuousActions);
}
}
}

3
com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3ActuatorTests.cs.meta


fileFormatVersion: 2
guid: 2edf24df24ac426085cb31a94d063683
timeCreated: 1603392289

314
com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs


using System;
using System.Collections.Generic;
using System.IO;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Extensions.Match3;
using UnityEngine;
using Unity.MLAgents.Extensions.Tests.Sensors;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Tests.Match3
{
public class Match3SensorTests
{
// Whether the expected PNG data should be written to a file.
// Only set this to true if the compressed observation format changes.
private bool WritePNGDataToFile = false;
[Test]
public void TestVectorObservations()
{
var boardString =
@"000
000
010";
var gameObj = new GameObject("board");
var board = gameObj.AddComponent<StringBoard>();
board.SetBoard(boardString);
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
sensorComponent.ObservationType = Match3ObservationType.Vector;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3 * 3 * 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
var expectedObs = new float[]
{
1, 0, /**/ 0, 1, /**/ 1, 0,
1, 0, /**/ 1, 0, /**/ 1, 0,
1, 0, /**/ 1, 0, /**/ 1, 0,
};
SensorTestHelper.CompareObservation(sensor, expectedObs);
}
[Test]
public void TestVectorObservationsSpecial()
{
var boardString =
@"000
000
010";
var specialString =
@"010
200
000";
var gameObj = new GameObject("board");
var board = gameObj.AddComponent<StringBoard>();
board.SetBoard(boardString);
board.SetSpecial(specialString);
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
sensorComponent.ObservationType = Match3ObservationType.Vector;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3 * 3 * (2 + 3) };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
var expectedObs = new float[]
{
1, 0, 1, 0, 0, /* (0, 0) */ 0, 1, 1, 0, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */
1, 0, 0, 0, 1, /* (0, 2) */ 1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 1, 0, 0, /* (0, 0) */
1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 0, 1, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */
};
SensorTestHelper.CompareObservation(sensor, expectedObs);
}
[Test]
public void TestVisualObservations()
{
var boardString =
@"000
000
010";
var gameObj = new GameObject("board");
var board = gameObj.AddComponent<StringBoard>();
board.SetBoard(boardString);
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());
var expectedObs = new float[]
{
1, 0, /**/ 0, 1, /**/ 1, 0,
1, 0, /**/ 1, 0, /**/ 1, 0,
1, 0, /**/ 1, 0, /**/ 1, 0,
};
SensorTestHelper.CompareObservation(sensor, expectedObs);
var expectedObs3D = new float[,,]
{
{{1, 0}, {0, 1}, {1, 0}},
{{1, 0}, {1, 0}, {1, 0}},
{{1, 0}, {1, 0}, {1, 0}},
};
SensorTestHelper.CompareObservation(sensor, expectedObs3D);
}
[Test]
public void TestVisualObservationsSpecial()
{
var boardString =
@"000
000
010";
var specialString =
@"010
200
000";
var gameObj = new GameObject("board");
var board = gameObj.AddComponent<StringBoard>();
board.SetBoard(boardString);
board.SetSpecial(specialString);
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
sensorComponent.ObservationType = Match3ObservationType.UncompressedVisual;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(SensorCompressionType.None, sensor.GetCompressionType());
var expectedObs = new float[]
{
1, 0, 1, 0, 0, /* (0, 0) */ 0, 1, 1, 0, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */
1, 0, 0, 0, 1, /* (0, 2) */ 1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 1, 0, 0, /* (0, 0) */
1, 0, 1, 0, 0, /* (0, 0) */ 1, 0, 0, 1, 0, /* (0, 1) */ 1, 0, 1, 0, 0, /* (0, 0) */
};
SensorTestHelper.CompareObservation(sensor, expectedObs);
var expectedObs3D = new float[,,]
{
{{1, 0, 1, 0, 0}, {0, 1, 1, 0, 0}, {1, 0, 1, 0, 0}},
{{1, 0, 0, 0, 1}, {1, 0, 1, 0, 0}, {1, 0, 1, 0, 0}},
{{1, 0, 1, 0, 0}, {1, 0, 0, 1, 0}, {1, 0, 1, 0, 0}},
};
SensorTestHelper.CompareObservation(sensor, expectedObs3D);
}
[Test]
public void TestCompressedVisualObservations()
{
var boardString =
@"000
000
010";
var gameObj = new GameObject("board");
var board = gameObj.AddComponent<StringBoard>();
board.SetBoard(boardString);
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
sensorComponent.ObservationType = Match3ObservationType.CompressedVisual;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3, 3, 2 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());
var pngData = sensor.GetCompressedObservation();
if (WritePNGDataToFile)
{
// Enable this if the format of the observation changes
SavePNGs(pngData, "match3obs");
}
var expectedPng = LoadPNGs("match3obs", 1);
Assert.AreEqual(expectedPng, pngData);
}
[Test]
public void TestCompressedVisualObservationsSpecial()
{
var boardString =
@"000
000
010";
var specialString =
@"010
200
000";
var gameObj = new GameObject("board");
var board = gameObj.AddComponent<StringBoard>();
board.SetBoard(boardString);
board.SetSpecial(specialString);
var sensorComponent = gameObj.AddComponent<Match3SensorComponent>();
sensorComponent.ObservationType = Match3ObservationType.CompressedVisual;
var sensor = sensorComponent.CreateSensor();
var expectedShape = new[] { 3, 3, 2 + 3 };
Assert.AreEqual(expectedShape, sensorComponent.GetObservationShape());
Assert.AreEqual(expectedShape, sensor.GetObservationShape());
Assert.AreEqual(SensorCompressionType.PNG, sensor.GetCompressionType());
var concatenatedPngData = sensor.GetCompressedObservation();
var pathPrefix = "match3obs_special";
if (WritePNGDataToFile)
{
// Enable this if the format of the observation changes
SavePNGs(concatenatedPngData, pathPrefix);
}
var expectedPng = LoadPNGs(pathPrefix, 2);
Assert.AreEqual(expectedPng, concatenatedPngData);
}
/// <summary>
/// Helper method for un-concatenating PNG observations.
/// </summary>
/// <param name="concatenated"></param>
/// <returns></returns>
List<byte[]> SplitPNGs(byte[] concatenated)
{
var pngsOut = new List<byte[]>();
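// The standard 8-byte PNG signature: 0x89 followed by "PNG", CR, LF, 0x1A, LF.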
var pngHeader = new byte[] { 137, 80, 78, 71, 13, 10, 26, 10 };
var current = new List<byte>();
for (var i = 0; i < concatenated.Length; i++)
{
current.Add(concatenated[i]);
// Check if the header starts at the next position
// If so, we'll start a new output array.
var headerIsNext = false;
if (i + 1 < concatenated.Length - pngHeader.Length)
{
for (var j = 0; j < pngHeader.Length; j++)
{
if (concatenated[i + 1 + j] != pngHeader[j])
{
break;
}
if (j == pngHeader.Length - 1)
{
headerIsNext = true;
}
}
}
if (headerIsNext)
{
pngsOut.Add(current.ToArray());
current = new List<byte>();
}
}
pngsOut.Add(current.ToArray());
return pngsOut;
}
void SavePNGs(byte[] concatenatedPngData, string pathPrefix)
{
var splitPngs = SplitPNGs(concatenatedPngData);
for (var i = 0; i < splitPngs.Count; i++)
{
var pngData = splitPngs[i];
var path = $"Packages/com.unity.ml-agents.extensions/Tests/Editor/Match3/{pathPrefix}{i}.png";
using (var sw = File.Create(path))
{
foreach (var b in pngData)
{
sw.WriteByte(b);
}
}
}
}
byte[] LoadPNGs(string pathPrefix, int numExpected)
{
var bytesOut = new List<byte>();
for (var i = 0; i < numExpected; i++)
{
var path = $"Packages/com.unity.ml-agents.extensions/Tests/Editor/Match3/{pathPrefix}{i}.png";
var res = File.ReadAllBytes(path);
bytesOut.AddRange(res);
}
return bytesOut.ToArray();
}
}
}

3
com.unity.ml-agents.extensions/Tests/Editor/Match3/Match3SensorTests.cs.meta


fileFormatVersion: 2
guid: dfe94a9d6e994f408cb97d07dd44c994
timeCreated: 1603493723

60
com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs


using System;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Match3;
namespace Unity.MLAgents.Extensions.Tests.Match3
{
public class MoveTests
{
[Test]
public void TestMoveEquivalence()
{
var moveUp = Move.FromPositionAndDirection(1, 1, Direction.Up, 10, 10);
var moveDown = Move.FromPositionAndDirection(2, 1, Direction.Down, 10, 10);
Assert.AreEqual(moveUp.MoveIndex, moveDown.MoveIndex);
var moveRight = Move.FromPositionAndDirection(1, 1, Direction.Right, 10, 10);
var moveLeft = Move.FromPositionAndDirection(1, 2, Direction.Left, 10, 10);
Assert.AreEqual(moveRight.MoveIndex, moveLeft.MoveIndex);
}
[Test]
public void TestNext()
{
var maxRows = 8;
var maxCols = 13;
// make sure using Next agrees with FromMoveIndex.
var advanceMove = Move.FromMoveIndex(0, maxRows, maxCols);
for (var moveIndex = 0; moveIndex < Move.NumPotentialMoves(maxRows, maxCols); moveIndex++)
{
var moveFromIndex = Move.FromMoveIndex(moveIndex, maxRows, maxCols);
Assert.AreEqual(advanceMove.MoveIndex, moveFromIndex.MoveIndex);
Assert.AreEqual(advanceMove.Row, moveFromIndex.Row);
Assert.AreEqual(advanceMove.Column, moveFromIndex.Column);
Assert.AreEqual(advanceMove.Direction, moveFromIndex.Direction);
advanceMove.Next(maxRows, maxCols);
}
}
// These are off the board
[TestCase(-1, 5, Direction.Up)]
[TestCase(10, 5, Direction.Up)]
[TestCase(5, -1, Direction.Up)]
[TestCase(5, 10, Direction.Up)]
// These are on the board but would move off
[TestCase(0, 5, Direction.Down)]
[TestCase(9, 5, Direction.Up)]
[TestCase(5, 0, Direction.Left)]
[TestCase(5, 9, Direction.Right)]
public void TestInvalidMove(int row, int col, Direction dir)
{
int numRows = 10, numCols = 10;
Assert.Throws<IndexOutOfRangeException>(() =>
{
Move.FromPositionAndDirection(row, col, dir, numRows, numCols);
});
}
}
}

3
com.unity.ml-agents.extensions/Tests/Editor/Match3/MoveTests.cs.meta


fileFormatVersion: 2
guid: 42981032af6f4241ae20fe24e898f60b
timeCreated: 1601336681

3
com.unity.ml-agents.extensions/Tests/Editor/Match3/match3obs0.png

Width: 3  |  Height: 3  |  Size: 81 B

Some files were not shown because too many files changed in this diff
