|
|
|
|
|
|
from math import floor
from typing import Tuple, Optional, Union
|
|
|
|
|
|
|
from mlagents.trainers.exception import UnityTrainerException |
|
|
|
from mlagents.trainers.torch.layers import linear_layer, Initialization, Swish

import torch
from torch import nn
|
|
|
|
|
|
self.running_variance.copy_(other_normalizer.running_variance.data) |
|
|
|
|
|
|
|
|
|
|
|
def conv_output_shape(
    h_w: Tuple[int, int],
    kernel_size: Union[int, Tuple[int, int]] = 1,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
) -> Tuple[int, int]:
    """
    Computes the height and width of the output of a convolution layer.
    kernel_size, stride, padding and dilation correspond to the inputs of the
    torch.nn.Conv2d layer (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)
    :param h_w: The height and width of the input.
    :param kernel_size: The size of the kernel of the convolution (can be an int or a
    tuple [height, width])
    :param stride: The stride of the convolution
    :param padding: The padding of the convolution
    :param dilation: The dilation of the convolution
    """
    if not isinstance(kernel_size, tuple):
        kernel_size = (int(kernel_size), int(kernel_size))
    # Standard Conv2d output-size formula, floored to an integer.
    height = floor(
        ((h_w[0] + (2 * padding) - (dilation * (kernel_size[0] - 1)) - 1) / stride) + 1
    )
    width = floor(
        ((h_w[1] + (2 * padding) - (dilation * (kernel_size[1] - 1)) - 1) / stride) + 1
    )
    return height, width
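
# A quick worked example (illustrative only, not part of the original module): an
# 84x84 input passed through an 8x8 kernel with stride 4 and no padding shrinks to
# 20x20, since floor((84 - 8) / 4) + 1 = 20.
#     conv_output_shape((84, 84), kernel_size=8, stride=4)  # -> (20, 20)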
|
|
|
""" |
|
|
|
Calculates the output shape (height and width) of the output of a max pooling layer. |
|
|
|
kernel_size corresponds to the inputs of the |
|
|
|
torch.nn.MaxPool2d layer (https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html) |
|
|
|
:param kernel_size: The size of the kernel of the convolution |
|
|
|
""" |
|
|
|
height = (h_w[0] - kernel_size) // 2 + 1 |
|
|
|
width = (h_w[1] - kernel_size) // 2 + 1 |
|
|
|
return height, width |
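
# Worked example (illustrative only): with the 3x3, stride-2 pooling used by the
# encoders below, an 84x84 feature map shrinks to 41x41, since (84 - 3) // 2 + 1 = 41.
#     pool_out_shape((84, 84), 3)  # -> (41, 41)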
|
|
|
|
|
|
return hidden |
|
|
|
|
|
|
|
|
|
|
|
class ResNetBlock(nn.Module): |
|
|
|
def __init__(self, channel: int): |
|
|
|
""" |
|
|
|
Creates a ResNet Block. |
|
|
|
:param channel: The number of channels in the input (and output) tensors of the |
|
|
|
convolutions |
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
self.layers = nn.Sequential( |
|
|
|
Swish(), |
|
|
|
nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1), |
|
|
|
Swish(), |
|
|
|
nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1), |
|
|
|
) |
|
|
|
|
|
|
|
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: |
|
|
|
return input_tensor + self.layers(input_tensor) |
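
# Shape sketch (illustrative only, with hypothetical tensor sizes): both convolutions
# use 3x3 kernels, stride 1 and padding 1, so height, width and channel count are all
# preserved and the residual addition in forward() is well defined. This is what lets
# ResNetVisualEncoder stack several blocks per resolution stage.
#     block = ResNetBlock(32)
#     block(torch.zeros(1, 32, 20, 20)).shape  # -> torch.Size([1, 32, 20, 20])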
|
|
|
|
|
|
|
|
|
|
|
class ResNetVisualEncoder(nn.Module): |
|
|
|
    def __init__(
        self, height: int, width: int, initial_channels: int, final_hidden: int
    ):
|
|
|
super().__init__() |
|
|
|
|
|
|
self.layers.append(nn.MaxPool2d([3, 3], [2, 2])) |
|
|
|
height, width = pool_out_shape((height, width), 3) |
|
|
|
            for _ in range(n_blocks):
                self.layers.append(ResNetBlock(channel))
            last_channel = channel
|
|
|
self.layers.append(Swish()) |
|
|
|
        # The flattened convolutional output is projected to final_hidden features.
        self.dense = linear_layer(
            last_channel * height * width,
            final_hidden,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=1.0,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
            hidden = layer(hidden)
|
|
|
before_out = hidden.view(batch_size, -1) |
|
|
|
return torch.relu(self.dense(before_out)) |
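
# Usage sketch (hypothetical values; the exact channel progression and observation
# layout are defined in parts of the encoder not shown here):
#     encoder = ResNetVisualEncoder(height=84, width=84, initial_channels=3, final_hidden=256)
#     obs = torch.zeros(8, 3, 84, 84)   # assuming an NCHW batch of visual observations
#     features = encoder(obs)           # shape (8, 256), ready for the policy/value heads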