浏览代码

Merge branch 'main' into develop-api-documentation-update

Syncing with main.
/develop/api-documentation-update-some-fixes
Miguel Alonso Jr 3 年前
当前提交
97b7d5c6
共有 18 个文件被更改,包括 1281 次插入100 次删除
  1. 2
      .yamato/com.unity.ml-agents-optional-dep-tests.yml
  2. 29
      LICENSE
  3. 29
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs
  4. 28
      com.unity.ml-agents.extensions/LICENSE.md
  5. 1
      com.unity.ml-agents/CHANGELOG.md
  6. 28
      com.unity.ml-agents/LICENSE.md
  7. 38
      com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs
  8. 13
      com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs
  9. 8
      config/sac/Hallway.yaml
  10. 58
      docs/Learning-Environment-Design-Agents.md
  11. 2
      docs/Training-Configuration-File.md
  12. 48
      gym-unity/README.md
  13. 1
      ml-agents/mlagents/trainers/settings.py
  14. 42
      ml-agents/mlagents/trainers/tests/torch/test_encoders.py
  15. 23
      ml-agents/mlagents/trainers/tests/torch/test_utils.py
  16. 24
      ml-agents/mlagents/trainers/torch/encoders.py
  17. 6
      ml-agents/mlagents/trainers/torch/utils.py
  18. 1001
      docs/images/grid_sensor.png

2
.yamato/com.unity.ml-agents-optional-dep-tests.yml


pull_request.target match "release.+") AND
NOT pull_request.draft AND
(pull_request.changes.any match "com.unity.ml-agents/**" OR
pull_request.changes.any match ".yamato/com.unity.ml-agents-test.yml")
pull_request.changes.any match ".yamato/com.unity.ml-agents-optional-dep-tests.yml")
{% endfor %}

29
LICENSE


Copyright 2017-2021 Unity Technologies
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2017 Unity Technologies
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

29
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs


string m_BehaviorNameOverrideDirectory;
private string m_OriginalBehaviorName;
private List<string> m_OverrideExtensions = new List<string>();
// Cached loaded NNModels, with the behavior name as the key.

{
GetAssetPathFromCommandLine();
return !string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory);
}
}
/// <summary>
/// The original behavior name of the agent. The actual behavior name will change when it is overridden.
/// </summary>
public string OriginalBehaviorName
{
get
{
if (string.IsNullOrEmpty(m_OriginalBehaviorName))
{
var bp = m_Agent.GetComponent<BehaviorParameters>();
m_OriginalBehaviorName = bp.BehaviorName;
}
return m_OriginalBehaviorName;
}
}

string overrideError = null;
m_Agent.LazyInitialize();
var bp = m_Agent.GetComponent<BehaviorParameters>();
var behaviorName = bp.BehaviorName;
nnModel = GetModelForBehaviorName(behaviorName);
nnModel = GetModelForBehaviorName(OriginalBehaviorName);
}
catch (Exception e)
{

if (string.IsNullOrEmpty(overrideError))
{
overrideError =
$"Didn't find a model for behaviorName {behaviorName}. Make " +
$"Didn't find a model for behaviorName {OriginalBehaviorName}. Make " +
"sure the behaviorName is set correctly in the commandline " +
"and that the model file exists";
}

var modelName = nnModel != null ? nnModel.name : "<null>";
Debug.Log($"Overriding behavior {behaviorName} for agent with model {modelName}");
Debug.Log($"Overriding behavior {OriginalBehaviorName} for agent with model {modelName}");
m_Agent.SetModel(GetOverrideBehaviorName(behaviorName), nnModel);
m_Agent.SetModel(GetOverrideBehaviorName(OriginalBehaviorName), nnModel);
overrideOk = true;
}
catch (Exception e)

28
com.unity.ml-agents.extensions/LICENSE.md


com.unity.ml-agents.extensions copyright © 2020 Unity Technologies ApS
Copyright 2020-2021 Unity Technologies
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Apache License

file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2017 Unity Technologies
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

1
com.unity.ml-agents/CHANGELOG.md


### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Added a fully connected visual encoder for environments with very small image inputs. (#5351)
### Bug Fixes

28
com.unity.ml-agents/LICENSE.md


com.unity.ml-agents copyright © 2020 Unity Technologies ApS
Copyright 2017-2021 Unity Technologies
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Apache License

file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2017 Unity Technologies
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

38
com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs


using System.Collections.Generic;
using System.Linq;
using UnityEngine;
namespace Unity.MLAgents.Sensors

{
// dummy sensor only used for debug gizmo
GridSensorBase m_DebugSensor;
List<ISensor> m_Sensors;
List<GridSensorBase> m_Sensors;
internal BoxOverlapChecker m_BoxOverlapChecker;
[HideInInspector, SerializeField]

/// <inheritdoc/>
public override ISensor[] CreateSensors()
{
m_Sensors = new List<ISensor>();
m_BoxOverlapChecker = new BoxOverlapChecker(
m_CellScale,
m_GridSize,

m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None);
m_BoxOverlapChecker.RegisterDebugSensor(m_DebugSensor);
var gridSensors = GetGridSensors();
if (gridSensors == null || gridSensors.Length < 1)
m_Sensors = GetGridSensors().ToList();
if (m_Sensors == null || m_Sensors.Count < 1)
foreach (var sensor in gridSensors)
// Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
m_Sensors[0].m_BoxOverlapChecker = m_BoxOverlapChecker;
foreach (var sensor in m_Sensors)
if (ObservationStacks != 1)
m_BoxOverlapChecker.RegisterSensor(sensor);
}
if (ObservationStacks != 1)
{
var sensors = new ISensor[m_Sensors.Count];
for (var i = 0; i < m_Sensors.Count; i++)
m_Sensors.Add(new StackingSensor(sensor, ObservationStacks));
sensors[i] = new StackingSensor(m_Sensors[i], ObservationStacks);
else
{
m_Sensors.Add(sensor);
}
m_BoxOverlapChecker.RegisterSensor(sensor);
return sensors;
}
else
{
return m_Sensors.ToArray();
// Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
((GridSensorBase)m_Sensors[0]).m_BoxOverlapChecker = m_BoxOverlapChecker;
return m_Sensors.ToArray();
}
/// <summary>

m_BoxOverlapChecker.ColliderMask = m_ColliderMask;
foreach (var sensor in m_Sensors)
{
((GridSensorBase)sensor).CompressionType = m_CompressionType;
sensor.CompressionType = m_CompressionType;
}
}
}

13
com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs


gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true);
gridSensorComponent.CreateSensors();
var componentSensor = (List<ISensor>)typeof(GridSensorComponent).GetField("m_Sensors",
var componentSensor = (List<GridSensorBase>)typeof(GridSensorComponent).GetField("m_Sensors",
BindingFlags.Instance | BindingFlags.NonPublic).GetValue(gridSensorComponent);
Assert.AreEqual(componentSensor.Count, 1);
}

{
gridSensorComponent.CreateSensors();
});
}
[Test]
public void TestStackedSensors()
{
testGo.tag = k_Tag2;
string[] tags = { k_Tag1, k_Tag2 };
gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true);
gridSensorComponent.ObservationStacks = 3;
var sensors = gridSensorComponent.CreateSensors();
Assert.IsInstanceOf(typeof(StackingSensor), sensors[0]);
}
}
}

8
config/sac/Hallway.yaml


hyperparameters:
learning_rate: 0.0003
learning_rate_schedule: constant
batch_size: 128
buffer_size: 50000
batch_size: 512
buffer_size: 200000
buffer_init_steps: 0
tau: 0.005
steps_per_update: 10.0

num_layers: 2
vis_encode_type: simple
memory:
sequence_length: 32
sequence_length: 64
memory_size: 128
reward_signals:
extrinsic:

max_steps: 5000000
max_steps: 4000000
time_horizon: 64
summary_freq: 10000

58
docs/Learning-Environment-Design-Agents.md


- Use as few rays and tags as necessary to solve the problem in order to improve
learning stability and agent performance.
### Grid Observations
Grid-base observations combine the advantages of 2D spatial representation in
visual observations, and the flexibility of defining detectable objects in
RayCast observations. The sensor uses a set of box queries in a grid shape and
gives a top-down 2D view around the agent. This can be implemented by adding a
`GridSensorComponent` to the Agent GameObject.
During observations, the sensor detects the presence of detectable objects in
each cell and encode that into one-hot representation. The collected information
from each cell forms a 3D tensor observation and will be fed into the
convolutional neural network (CNN) of the agent policy just like visual
observations.
![Agent with GridSensorComponent](images/grid_sensor.png)
The sensor component has the following settings:
- _Cell Scale_ The scale of each cell in the grid.
- _Grid Size_ Number of cells on each side of the grid.
- _Agent Game Object_ The Agent that holds the grid sensor. This is used to
disambiguate objects with the same tag as the agent so that the agent doesn't
detect itself.
- _Rotate With Agent_ Whether the grid rotates with the Agent.
- _Detectable Tags_ A list of strings corresponding to the types of objects that
the Agent should be able to distinguish between.
- _Collider Mask_ The [LayerMask](https://docs.unity3d.com/ScriptReference/LayerMask.html)
passed to the collider detection. This can be used to ignore certain types
of objects.
- _Initial Collider Buffer Size_ The initial size of the Collider buffer used
in the non-allocating Physics calls for each cell.
- _Max Collider Buffer Size_ The max size of the Collider buffer used in the
non-allocating Physics calls for each cell.
The observation for each grid cell is a one-hot encoding of the detected object.
The total size of the created observations is
```
GridSize.x * GridSize.z * Num Detectable Tags
```
so the number of detectable tags and size of the grid should be kept as small as
possible to reduce the amount of data used. This makes a trade-off between the
granularity of the observation and training speed.
To allow more variety of observations that grid sensor can capture, the
`GridSensorComponent` and the underlying `GridSensorBase` also provides interfaces
that can be overridden to collect customized observation from detected objects.
See the doc on
[extending grid Sensors](https://github.com/Unity-Technologies/ml-agents/blob/release_17/com.unity.ml-agents.extensions/Documentation~/CustomGridSensors.md)
for more details on custom grid sensors.
#### Grid Observation Summary & Best Practices
- Attach `GridSensorComponent` to use.
- This observation type is best used when there is relevant non-visual spatial information that
can be best captured in 2D representations.
- Use as small grid size and as few tags as necessary to solve the problem in order to improve
learning stability and agent performance.
### Variable Length Observations
It is possible for agents to collect observations from a varying number of

2
docs/Training-Configuration-File.md


| `network_settings -> hidden_units` | (default = `128`) Number of units in the hidden layers of the neural network. Correspond to how many units are in each fully connected layer of the neural network. For simple problems where the correct action is a straightforward combination of the observation inputs, this should be small. For problems where the action is a very complex interaction between the observation variables, this should be larger. <br><br> Typical range: `32` - `512` |
| `network_settings -> num_layers` | (default = `2`) The number of hidden layers in the neural network. Corresponds to how many hidden layers are present after the observation input, or after the CNN encoding of the visual observation. For simple problems, fewer layers are likely to train faster and more efficiently. More layers may be necessary for more complex control problems. <br><br> Typical range: `1` - `3` |
| `network_settings -> normalize` | (default = `false`) Whether normalization is applied to the vector observation inputs. This normalization is based on the running average and variance of the vector observation. Normalization can be helpful in cases with complex continuous control problems, but may be harmful with simpler discrete control problems. |
| `network_settings -> vis_encode_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsoon et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. |
| `network_settings -> vis_encode_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsoon et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. `fully_connected` uses a single fully connected dense layer as encoder and should be reserved for very small inputs. |
| `network_settings -> conditioning_type` | (default = `hyper`) Conditioning type for the policy using goal observations. <br><br> `none` treats the goal observations as regular observations, `hyper` (default) uses a HyperNetwork with goal observations as input to generate some of the weights of the policy. Note that when using `hyper` the number of parameters of the network increases greatly. Therefore, it is recommended to reduce the number of `hidden_units` when using this `conditioning_type`

48
gym-unity/README.md


Discrete. Otherwise, it will be converted into a MultiDiscrete. Defaults to
`False`.
- `allow_multiple_obs` will return a list of observations. The first elements contain the visual observations and the
last element contains the array of vector observations. If False the environment returns a single array (containing
a single visual observations, if present, otherwise the vector observation). Defaults to `False`.
- `allow_multiple_obs` will return a list of observations. The first elements
contain the visual observations and the last element contains the array of
vector observations. If False the environment returns a single array (containing
a single visual observations, if present, otherwise the vector observation).
Defaults to `False`.
- `action_space_seed` is the optional seed for action sampling. If non-None, will
be used to set the random seed on created gym.Space instances.
The returned environment `env` will function as a gym.

```
Next, create a file called `train_unity.py`. Then create an `/envs/` directory
and build the GridWorld environment to that directory. For more information on
and build the environment to that directory. For more information on
[here](../docs/Learning-Environment-Executable.md). Add the following code to
the `train_unity.py` file:
[here](../docs/Learning-Environment-Executable.md). Note that because of
limitations of the DQN baseline, the environment must have a single visual
observation, a single discrete action and a single Agent in the scene.
Add the following code to the `train_unity.py` file:
```python
import gym

from gym_unity.envs import UnityToGymWrapper
def main():
unity_env = UnityEnvironment("./envs/GridWorld")
env = UnityToGymWrapper(unity_env, 0, uint8_visual=True)
unity_env = UnityEnvironment(<path-to-environment>)
env = UnityToGymWrapper(unity_env, uint8_visual=True)
"cnn", # conv_only is also a good choice for GridWorld
"cnn", # For visual inputs
lr=2.5e-4,
total_timesteps=1000000,
buffer_size=50000,

"""
def make_env(rank, use_visual=True): # pylint: disable=C0111
def _thunk():
unity_env = UnityEnvironment(env_directory)
env = UnityToGymWrapper(unity_env, rank, uint8_visual=True)
unity_env = UnityEnvironment(env_directory, base_port=5000 + rank)
env = UnityToGymWrapper(unity_env, uint8_visual=True)
env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
return env
return _thunk

return DummyVecEnv([make_env(rank, use_visual=False)])
def main():
env = make_unity_env('./envs/GridWorld', 4, True)
env = make_unity_env(<path-to-environment>, 4, True)
ppo2.learn(
network="mlp",
env=env,

```python
game_version = 'v0' if sticky_actions else 'v4'
full_game_name = '{}NoFrameskip-{}'.format(game_name, game_version)
unity_env = UnityEnvironment('./envs/GridWorld')
unity_env = UnityEnvironment(<path-to-environment>)
`./envs/GridWorld` is the path to your built Unity executable. For more
`<path-to-environment>` is the path to your built Unity executable. For more
information on building Unity environments, see
[here](../docs/Learning-Environment-Executable.md), and note the Limitations
section below.

Since Dopamine is designed around variants of DQN, it is only compatible with
discrete action spaces, and specifically the Discrete Gym space. For
environments that use branched discrete action spaces (e.g.
[VisualBanana](../docs/Learning-Environment-Examples.md)), you can enable the
environments that use branched discrete action spaces, you can enable the
`flatten_branched` parameter in `UnityToGymWrapper`, which treats each
combination of branched actions as separate actions.

The hyperparameters provided by Dopamine are tailored to the Atari games, and
you will likely need to adjust them for ML-Agents environments. Here is a sample
`dopamine/agents/rainbow/configs/rainbow.gin` file that is known to work with
GridWorld.
a simple GridWorld.
```python
import dopamine.agents.rainbow.rainbow_agent

![Dopamine on GridWorld](images/dopamine_gridworld_plot.png)
### Example: VisualBanana
As an example of using the `flatten_branched` option, we also used the Rainbow
algorithm to train on the VisualBanana environment, and provide the results
below. The same hyperparameters were used as in the GridWorld case, except that
`replay_history` and `epsilon_decay` were increased to 100000.
![Dopamine on VisualBanana](images/dopamine_visualbanana_plot.png)

1
ml-agents/mlagents/trainers/settings.py


class EncoderType(Enum):
FULLY_CONNECTED = "fully_connected"
MATCH3 = "match3"
SIMPLE = "simple"
NATURE_CNN = "nature_cnn"

42
ml-agents/mlagents/trainers/tests/torch/test_encoders.py


from mlagents.trainers.torch.encoders import (
VectorInput,
Normalizer,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,

@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
@pytest.mark.parametrize(
"vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
"vis_class",
[
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
],
)
def test_visual_encoder(vis_class, image_size):
num_outputs = 128

encoding = enc(sample_input)
assert encoding.shape == (1, num_outputs)
@pytest.mark.parametrize(
"vis_class, size",
[
(SimpleVisualEncoder, 36),
(ResNetVisualEncoder, 36),
(NatureVisualEncoder, 36),
(SmallVisualEncoder, 10),
(FullyConnectedVisualEncoder, 36),
],
)
def test_visual_encoder_trains(vis_class, size):
torch.manual_seed(0)
image_size = (size, size, 1)
batch = 100
inputs = torch.cat(
[torch.zeros((batch,) + image_size), torch.ones((batch,) + image_size)], dim=0
)
target = torch.cat([torch.zeros((batch,)), torch.ones((batch,))], dim=0)
enc = vis_class(image_size[0], image_size[1], image_size[2], 1)
optimizer = torch.optim.Adam(enc.parameters(), lr=0.001)
for _ in range(15):
prediction = enc(inputs)[:, 0]
loss = torch.mean((target - prediction) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
assert loss.item() < 0.05

23
ml-agents/mlagents/trainers/tests/torch/test_utils.py


enc.forward(vis_input)
@pytest.mark.parametrize(
"encoder_type",
[
EncoderType.SIMPLE,
EncoderType.NATURE_CNN,
EncoderType.SIMPLE,
EncoderType.MATCH3,
],
)
def test_invalid_visual_input_size(encoder_type):
with pytest.raises(UnityTrainerException):
obs_spec = create_observation_specs_with_shapes(
[
(
ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] - 1,
ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type],
1,
)
]
)
ModelUtils.create_input_processors(obs_spec, 20, encoder_type, 20, False)
@pytest.mark.parametrize("num_visual", [0, 1, 2])
@pytest.mark.parametrize("num_vector", [0, 1, 2])
@pytest.mark.parametrize("normalize", [True, False])

24
ml-agents/mlagents/trainers/torch/encoders.py


self.normalizer.update(inputs)
class FullyConnectedVisualEncoder(nn.Module):
def __init__(
self, height: int, width: int, initial_channels: int, output_size: int
):
super().__init__()
self.output_size = output_size
self.input_size = height * width * initial_channels
self.dense = nn.Sequential(
linear_layer(
self.input_size,
self.output_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=1.41, # Use ReLU gain
),
nn.LeakyReLU(),
)
def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
if not exporting_to_onnx.is_exporting():
visual_obs = visual_obs.permute([0, 3, 1, 2])
hidden = visual_obs.reshape(-1, self.input_size)
return self.dense(hidden)
class SmallVisualEncoder(nn.Module):
"""
CNN architecture used by King in their Candy Crush predictor

6
ml-agents/mlagents/trainers/torch/utils.py


ResNetVisualEncoder,
NatureVisualEncoder,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType

# Minimum supported side for each encoder type. If refactoring an encoder, please
# adjust these also.
MIN_RESOLUTION_FOR_ENCODER = {
EncoderType.FULLY_CONNECTED: 1,
EncoderType.MATCH3: 5,
EncoderType.SIMPLE: 20,
EncoderType.NATURE_CNN: 36,

EncoderType.NATURE_CNN: NatureVisualEncoder,
EncoderType.RESNET: ResNetVisualEncoder,
EncoderType.MATCH3: SmallVisualEncoder,
EncoderType.FULLY_CONNECTED: FullyConnectedVisualEncoder,
}
return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)

# VISUAL
if dim_prop in ModelUtils.VALID_VISUAL_PROP:
visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type)
ModelUtils._check_resolution_for_encoder(
shape[0], shape[1], vis_encode_type
)
return (visual_encoder_class(shape[0], shape[1], shape[2], h_size), h_size)
# VECTOR
if dim_prop in ModelUtils.VALID_VECTOR_PROP:

1001
docs/images/grid_sensor.png
文件差异内容过多而无法显示
查看文件

正在加载...
取消
保存