您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
298 行
11 KiB
298 行
11 KiB
from typing import Tuple, Optional, Union
|
|
|
|
from mlagents.trainers.torch.layers import linear_layer, Initialization, Swish
|
|
|
|
from mlagents.torch_utils import torch, nn
|
|
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
|
|
|
|
|
|
class Normalizer(nn.Module):
    """
    Normalizes vector inputs using running statistics stored as module buffers.

    Three buffers are registered: ``normalization_steps`` (initialized to 1),
    ``running_mean`` (initialized to zeros) and ``running_variance``
    (initialized to ones). ``running_variance`` holds an accumulated sum of
    deviation products, not a per-step variance; it is divided by
    ``normalization_steps`` inside ``forward``.
    """

    def __init__(self, vec_obs_size: int):
        """
        :param vec_obs_size: Length of the running mean and variance buffers.
        """
        super().__init__()
        self.register_buffer("normalization_steps", torch.tensor(1))
        self.register_buffer("running_mean", torch.zeros(vec_obs_size))
        self.register_buffer("running_variance", torch.ones(vec_obs_size))

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Centers the inputs on the running mean, scales them by the square root
        of ``running_variance / normalization_steps`` and clamps to [-5, 5].
        :param inputs: The tensor to normalize.
        :return: The normalized, clamped tensor.
        """
        normalized_state = torch.clamp(
            (inputs - self.running_mean)
            / torch.sqrt(self.running_variance / self.normalization_steps),
            -5,
            5,
        )
        return normalized_state

    def update(self, vector_input: torch.Tensor) -> None:
        """
        Folds a batch of inputs into the running statistics without tracking
        gradients.
        :param vector_input: The batch to absorb; its first dimension is used
            as the number of new steps.
        """
        with torch.no_grad():
            steps_increment = vector_input.size()[0]
            total_new_steps = self.normalization_steps + steps_increment

            # Deviation of each sample from the mean before this update.
            input_to_old_mean = vector_input - self.running_mean
            new_mean: torch.Tensor = self.running_mean + (
                input_to_old_mean / total_new_steps
            ).sum(0)

            # Deviation of each sample from the mean after this update.
            input_to_new_mean = vector_input - new_mean
            new_variance = self.running_variance + (
                input_to_new_mean * input_to_old_mean
            ).sum(0)
            # Update references. This is much faster than in-place data update.
            self.running_mean: torch.Tensor = new_mean
            self.running_variance: torch.Tensor = new_variance
            self.normalization_steps: torch.Tensor = total_new_steps

    def copy_from(self, other_normalizer: "Normalizer") -> None:
        """
        Copies the three statistic buffers of another normalizer into this one,
        in place.
        :param other_normalizer: The normalizer whose buffers are copied.
        """
        # All three buffers are copied the same way, through ``.data``.
        self.normalization_steps.data.copy_(other_normalizer.normalization_steps.data)
        self.running_mean.data.copy_(other_normalizer.running_mean.data)
        self.running_variance.data.copy_(other_normalizer.running_variance.data)
|
|
|
|
|
|
def conv_output_shape(
    h_w: Tuple[int, int],
    kernel_size: Union[int, Tuple[int, int]] = 1,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
) -> Tuple[int, int]:
    """
    Calculates the output shape (height and width) of the output of a convolution layer.
    kernel_size, stride, padding and dilation correspond to the inputs of the
    torch.nn.Conv2d layer (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)
    :param h_w: The height and width of the input.
    :param kernel_size: The size of the kernel of the convolution (can be an int or a
    tuple (height, width): element 0 is applied to the height, element 1 to the width)
    :param stride: The stride of the convolution
    :param padding: The padding of the convolution
    :param dilation: The dilation of the convolution
    :return: The (height, width) of the convolution output.
    """
    if not isinstance(kernel_size, tuple):
        kernel_size = (int(kernel_size), int(kernel_size))

    def _out_dim(size: int, kernel: int) -> int:
        # Integer floor division keeps the result exact and avoids a float round trip.
        return (size + (2 * padding) - (dilation * (kernel - 1)) - 1) // stride + 1

    return _out_dim(h_w[0], kernel_size[0]), _out_dim(h_w[1], kernel_size[1])
|
|
|
|
|
|
def pool_out_shape(
    h_w: Tuple[int, int], kernel_size: int, stride: int = 2
) -> Tuple[int, int]:
    """
    Calculates the output shape (height and width) of the output of a max pooling layer.
    kernel_size and stride correspond to the inputs of the
    torch.nn.MaxPool2d layer (https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html)
    No padding or dilation is taken into account.
    :param h_w: The height and width of the input.
    :param kernel_size: The size of the kernel of the pooling window
    :param stride: The stride of the pooling window. Defaults to 2.
    :return: The (height, width) of the pooling output.
    """
    height = (h_w[0] - kernel_size) // stride + 1
    width = (h_w[1] - kernel_size) // stride + 1
    return height, width
|
|
|
|
|
|
class VectorInput(nn.Module):
    """
    Passes vector inputs through, optionally normalizing them with a Normalizer.
    """

    def __init__(self, input_size: int, normalize: bool = False):
        """
        :param input_size: Size handed to the Normalizer when one is created.
        :param normalize: Whether a Normalizer is created. When False, inputs
            are returned unchanged by forward.
        """
        super().__init__()
        self.normalizer: Optional[Normalizer] = (
            Normalizer(input_size) if normalize else None
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Returns the inputs, normalized if this module owns a normalizer.
        """
        if self.normalizer is None:
            return inputs
        return self.normalizer(inputs)

    def copy_normalization(self, other_input: "VectorInput") -> None:
        """
        Copies the normalizer state of other_input into this module. Does
        nothing unless both modules own a normalizer.
        """
        if self.normalizer is None or other_input.normalizer is None:
            return
        self.normalizer.copy_from(other_input.normalizer)

    def update_normalization(self, inputs: torch.Tensor) -> None:
        """
        Updates the normalizer with the given inputs. Does nothing when this
        module has no normalizer.
        """
        if self.normalizer is None:
            return
        self.normalizer.update(inputs)
|
|
|
|
|
|
class FullyConnectedVisualEncoder(nn.Module):
    """
    Encodes a visual observation by flattening it and applying a single linear
    layer followed by a LeakyReLU.
    """

    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        """
        :param height: Height of the visual observation.
        :param width: Width of the visual observation.
        :param initial_channels: Number of channels of the visual observation.
        :param output_size: Number of output features of the linear layer.
        """
        super().__init__()
        self.output_size = output_size
        # Every value of the observation feeds the linear layer directly.
        self.input_size = height * width * initial_channels
        projection = linear_layer(
            self.input_size,
            self.output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=1.41,  # Use ReLU gain
        )
        self.dense = nn.Sequential(projection, nn.LeakyReLU())

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        """
        :param visual_obs: A 4-dimensional batch of visual observations.
        :return: The encoded observations, one row of output_size per flattened input.
        """
        exporting = exporting_to_onnx.is_exporting()
        if not exporting:
            # Moves axis 3 to position 1; presumably channels-last to
            # channels-first -- confirm against callers.
            visual_obs = visual_obs.permute([0, 3, 1, 2])
        flattened = visual_obs.reshape(-1, self.input_size)
        return self.dense(flattened)
|
|
|
|
|
|
class SmallVisualEncoder(nn.Module):
    """
    CNN architecture used by King in their Candy Crush predictor
    https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning

    Two 3x3 stride-1 convolutions (35 then 144 channels), each followed by a
    LeakyReLU, feed a linear layer with a LeakyReLU on top.
    """

    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        """
        :param height: Height of the visual observation.
        :param width: Width of the visual observation.
        :param initial_channels: Number of input channels of the first convolution.
        :param output_size: Number of output features of the final linear layer.
        """
        super().__init__()
        self.h_size = output_size
        # Track the spatial size through the two identical 3x3, stride-1 convolutions.
        spatial = (height, width)
        for _ in range(2):
            spatial = conv_output_shape(spatial, 3, 1)
        self.final_flat = spatial[0] * spatial[1] * 144

        self.conv_layers = nn.Sequential(
            nn.Conv2d(initial_channels, 35, kernel_size=[3, 3], stride=[1, 1]),
            nn.LeakyReLU(),
            nn.Conv2d(35, 144, kernel_size=[3, 3], stride=[1, 1]),
            nn.LeakyReLU(),
        )
        projection = linear_layer(
            self.final_flat,
            self.h_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=1.41,  # Use ReLU gain
        )
        self.dense = nn.Sequential(projection, nn.LeakyReLU())

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        """
        :param visual_obs: A 4-dimensional batch of visual observations.
        :return: The encoded observations.
        """
        exporting = exporting_to_onnx.is_exporting()
        if not exporting:
            # Moves axis 3 to position 1; presumably channels-last to
            # channels-first -- confirm against callers.
            visual_obs = visual_obs.permute([0, 3, 1, 2])
        features = self.conv_layers(visual_obs)
        flattened = features.reshape(-1, self.final_flat)
        return self.dense(flattened)
|
|
|
|
|
|
class SimpleVisualEncoder(nn.Module):
    """
    Two-convolution visual encoder: an 8x8 stride-4 convolution (16 channels)
    and a 4x4 stride-2 convolution (32 channels), each followed by a LeakyReLU,
    then a linear layer with a LeakyReLU on top.
    """

    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        """
        :param height: Height of the visual observation.
        :param width: Width of the visual observation.
        :param initial_channels: Number of input channels of the first convolution.
        :param output_size: Number of output features of the final linear layer.
        """
        super().__init__()
        self.h_size = output_size
        # (kernel, stride) of each convolution, used to track the spatial size.
        spatial = (height, width)
        for kernel, stride in ((8, 4), (4, 2)):
            spatial = conv_output_shape(spatial, kernel, stride)
        self.final_flat = spatial[0] * spatial[1] * 32

        self.conv_layers = nn.Sequential(
            nn.Conv2d(initial_channels, 16, kernel_size=[8, 8], stride=[4, 4]),
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, kernel_size=[4, 4], stride=[2, 2]),
            nn.LeakyReLU(),
        )
        projection = linear_layer(
            self.final_flat,
            self.h_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=1.41,  # Use ReLU gain
        )
        self.dense = nn.Sequential(projection, nn.LeakyReLU())

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        """
        :param visual_obs: A 4-dimensional batch of visual observations.
        :return: The encoded observations.
        """
        exporting = exporting_to_onnx.is_exporting()
        if not exporting:
            # Moves axis 3 to position 1; presumably channels-last to
            # channels-first -- confirm against callers.
            visual_obs = visual_obs.permute([0, 3, 1, 2])
        features = self.conv_layers(visual_obs)
        flattened = features.reshape(-1, self.final_flat)
        return self.dense(flattened)
|
|
|
|
|
|
class NatureVisualEncoder(nn.Module):
    """
    Three-convolution visual encoder: 8x8 stride-4 (32 channels), 4x4 stride-2
    (64 channels) and 3x3 stride-1 (64 channels) convolutions, each followed by
    a LeakyReLU, then a linear layer with a LeakyReLU on top.
    """

    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        """
        :param height: Height of the visual observation.
        :param width: Width of the visual observation.
        :param initial_channels: Number of input channels of the first convolution.
        :param output_size: Number of output features of the final linear layer.
        """
        super().__init__()
        self.h_size = output_size
        # (kernel, stride) of each convolution, used to track the spatial size.
        spatial = (height, width)
        for kernel, stride in ((8, 4), (4, 2), (3, 1)):
            spatial = conv_output_shape(spatial, kernel, stride)
        self.final_flat = spatial[0] * spatial[1] * 64

        self.conv_layers = nn.Sequential(
            nn.Conv2d(initial_channels, 32, kernel_size=[8, 8], stride=[4, 4]),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=[4, 4], stride=[2, 2]),
            nn.LeakyReLU(),
            nn.Conv2d(64, 64, kernel_size=[3, 3], stride=[1, 1]),
            nn.LeakyReLU(),
        )
        projection = linear_layer(
            self.final_flat,
            self.h_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=1.41,  # Use ReLU gain
        )
        self.dense = nn.Sequential(projection, nn.LeakyReLU())

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        """
        :param visual_obs: A 4-dimensional batch of visual observations.
        :return: The encoded observations.
        """
        exporting = exporting_to_onnx.is_exporting()
        if not exporting:
            # Moves axis 3 to position 1; presumably channels-last to
            # channels-first -- confirm against callers.
            visual_obs = visual_obs.permute([0, 3, 1, 2])
        features = self.conv_layers(visual_obs)
        flattened = features.reshape([-1, self.final_flat])
        return self.dense(flattened)
|
|
|
|
|
|
class ResNetBlock(nn.Module):
    """
    Residual block: two (Swish, 3x3 convolution) stages whose output is added
    back to the block input.
    """

    def __init__(self, channel: int):
        """
        Creates a ResNet Block.
        :param channel: The number of channels in the input (and output) tensors of the
        convolutions
        """
        super().__init__()
        stages = []
        # Two identical stages; padding=1 with a 3x3, stride-1 kernel keeps
        # the spatial size, so the skip connection can be added directly.
        for _ in range(2):
            stages.append(Swish())
            stages.append(nn.Conv2d(channel, channel, [3, 3], [1, 1], padding=1))
        self.layers = nn.Sequential(*stages)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        :param input_tensor: The tensor fed to the block.
        :return: The input plus the output of the block's layers.
        """
        residual = self.layers(input_tensor)
        return input_tensor + residual
|
|
|
|
|
|
class ResNetVisualEncoder(nn.Module):
    """
    Visual encoder made of three stacks (16, 32 and 32 channels). Each stack is
    a 3x3 convolution, a 3x3 stride-2 max pooling and two ResNetBlocks. A Swish
    and a linear layer follow, and a ReLU is applied to the linear output.
    """

    def __init__(
        self, height: int, width: int, initial_channels: int, output_size: int
    ):
        """
        :param height: Height of the visual observation.
        :param width: Width of the visual observation.
        :param initial_channels: Number of input channels of the first convolution.
        :param output_size: Number of output features of the final linear layer.
        """
        super().__init__()
        stack_channels = [16, 32, 32]  # channel for each stack
        blocks_per_stack = 2  # number of residual blocks
        modules = []
        in_channels = initial_channels
        for out_channels in stack_channels:
            modules.append(
                nn.Conv2d(in_channels, out_channels, [3, 3], [1, 1], padding=1)
            )
            modules.append(nn.MaxPool2d([3, 3], [2, 2]))
            # Only the pooling changes the spatial size (the convolutions are padded).
            height, width = pool_out_shape((height, width), 3)
            modules.extend(ResNetBlock(out_channels) for _ in range(blocks_per_stack))
            in_channels = out_channels
        modules.append(Swish())
        self.dense = linear_layer(
            stack_channels[-1] * height * width,
            output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=1.41,  # Use ReLU gain
        )
        self.sequential = nn.Sequential(*modules)

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        """
        :param visual_obs: A 4-dimensional batch of visual observations.
        :return: The encoded observations, one row per batch element.
        """
        exporting = exporting_to_onnx.is_exporting()
        if not exporting:
            # Moves axis 3 to position 1; presumably channels-last to
            # channels-first -- confirm against callers.
            visual_obs = visual_obs.permute([0, 3, 1, 2])
        n_samples = visual_obs.shape[0]
        features = self.sequential(visual_obs)
        flattened = features.reshape(n_samples, -1)
        return torch.relu(self.dense(flattened))
|