浏览代码

Adding a fully connected visual encoder for super small visual input + tests (#5351)

* initial commit for a fully connected visual encoder

* adding a test

* addressing comments

* Fixing error with minimal size of fully connected network

* adding documentation and changelog
/colab-links
GitHub 3 年前
当前提交
bb07eb45
共有 6 个文件被更改,包括 71 次插入2 次删除
  1. 1
      com.unity.ml-agents/CHANGELOG.md
  2. 2
      docs/Training-Configuration-File.md
  3. 1
      ml-agents/mlagents/trainers/settings.py
  4. 42
      ml-agents/mlagents/trainers/tests/torch/test_encoders.py
  5. 24
      ml-agents/mlagents/trainers/torch/encoders.py
  6. 3
      ml-agents/mlagents/trainers/torch/utils.py

1
com.unity.ml-agents/CHANGELOG.md


### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Added a fully connected visual encoder for environments with very small image inputs. (#5351)
### Bug Fixes

2
docs/Training-Configuration-File.md


| `network_settings -> hidden_units` | (default = `128`) Number of units in the hidden layers of the neural network. Correspond to how many units are in each fully connected layer of the neural network. For simple problems where the correct action is a straightforward combination of the observation inputs, this should be small. For problems where the action is a very complex interaction between the observation variables, this should be larger. <br><br> Typical range: `32` - `512` |
| `network_settings -> num_layers` | (default = `2`) The number of hidden layers in the neural network. Corresponds to how many hidden layers are present after the observation input, or after the CNN encoding of the visual observation. For simple problems, fewer layers are likely to train faster and more efficiently. More layers may be necessary for more complex control problems. <br><br> Typical range: `1` - `3` |
| `network_settings -> normalize` | (default = `false`) Whether normalization is applied to the vector observation inputs. This normalization is based on the running average and variance of the vector observation. Normalization can be helpful in cases with complex continuous control problems, but may be harmful with simpler discrete control problems. |
| `network_settings -> vis_encode_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsoon et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. |
| `network_settings -> vis_encode_type` | (default = `simple`) Encoder type for encoding visual observations. <br><br> `simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsoon et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. `fully_connected` uses a single fully connected dense layer as encoder and should be reserved for very small inputs. |
| `network_settings -> conditioning_type` | (default = `hyper`) Conditioning type for the policy using goal observations. <br><br> `none` treats the goal observations as regular observations, `hyper` (default) uses a HyperNetwork with goal observations as input to generate some of the weights of the policy. Note that when using `hyper` the number of parameters of the network increases greatly. Therefore, it is recommended to reduce the number of `hidden_units` when using this `conditioning_type`

1
ml-agents/mlagents/trainers/settings.py


class EncoderType(Enum):
FULLY_CONNECTED = "fully_connected"
MATCH3 = "match3"
SIMPLE = "simple"
NATURE_CNN = "nature_cnn"

42
ml-agents/mlagents/trainers/tests/torch/test_encoders.py


from mlagents.trainers.torch.encoders import (
VectorInput,
Normalizer,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,

@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
@pytest.mark.parametrize(
"vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
"vis_class",
[
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
],
)
def test_visual_encoder(vis_class, image_size):
num_outputs = 128

encoding = enc(sample_input)
assert encoding.shape == (1, num_outputs)
@pytest.mark.parametrize(
"vis_class, size",
[
(SimpleVisualEncoder, 36),
(ResNetVisualEncoder, 36),
(NatureVisualEncoder, 36),
(SmallVisualEncoder, 10),
(FullyConnectedVisualEncoder, 36),
],
)
def test_visual_encoder_trains(vis_class, size):
torch.manual_seed(0)
image_size = (size, size, 1)
batch = 100
inputs = torch.cat(
[torch.zeros((batch,) + image_size), torch.ones((batch,) + image_size)], dim=0
)
target = torch.cat([torch.zeros((batch,)), torch.ones((batch,))], dim=0)
enc = vis_class(image_size[0], image_size[1], image_size[2], 1)
optimizer = torch.optim.Adam(enc.parameters(), lr=0.001)
for _ in range(15):
prediction = enc(inputs)[:, 0]
loss = torch.mean((target - prediction) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
assert loss.item() < 0.05

24
ml-agents/mlagents/trainers/torch/encoders.py


self.normalizer.update(inputs)
class FullyConnectedVisualEncoder(nn.Module):
def __init__(
self, height: int, width: int, initial_channels: int, output_size: int
):
super().__init__()
self.output_size = output_size
self.input_size = height * width * initial_channels
self.dense = nn.Sequential(
linear_layer(
self.input_size,
self.output_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=1.41, # Use ReLU gain
),
nn.LeakyReLU(),
)
def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
if not exporting_to_onnx.is_exporting():
visual_obs = visual_obs.permute([0, 3, 1, 2])
hidden = visual_obs.reshape(-1, self.input_size)
return self.dense(hidden)
class SmallVisualEncoder(nn.Module):
"""
CNN architecture used by King in their Candy Crush predictor

3
ml-agents/mlagents/trainers/torch/utils.py


ResNetVisualEncoder,
NatureVisualEncoder,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType

# Minimum supported side for each encoder type. If refactoring an encoder, please
# adjust these also.
MIN_RESOLUTION_FOR_ENCODER = {
EncoderType.FULLY_CONNECTED: 1,
EncoderType.MATCH3: 5,
EncoderType.SIMPLE: 20,
EncoderType.NATURE_CNN: 36,

EncoderType.NATURE_CNN: NatureVisualEncoder,
EncoderType.RESNET: ResNetVisualEncoder,
EncoderType.MATCH3: SmallVisualEncoder,
EncoderType.FULLY_CONNECTED: FullyConnectedVisualEncoder,
}
return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)

正在加载...
取消
保存