您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
133 行
5.0 KiB
133 行
5.0 KiB
from mlagents.torch_utils import torch
|
|
from typing import List
|
|
import math
|
|
|
|
from mlagents.trainers.torch.layers import (
|
|
linear_layer,
|
|
Swish,
|
|
Initialization,
|
|
LayerNorm,
|
|
)
|
|
|
|
|
|
class ConditionalEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        goal_size: int,
        hidden_size: int,
        num_layers: int,
        num_conditional_layers: int,
        kernel_init: Initialization = Initialization.KaimingHeNormal,
        kernel_gain: float = 1.0,
        num_hyper_layers: int = 2,
    ):
        """
        ConditionalEncoder module. A fully connected network of which some of the
        weights are generated by a goal conditioning. Uses the HyperNetwork module to
        generate the weights of the network. Only the weights of the last
        "num_conditional_layers" layers will be generated by HyperNetworks, the others
        will use regular parameters. Every layer (regular or generated) is followed by
        a Swish activation, and the goal is also concatenated to the input of the
        first layer.
        :param input_size: The size of the input of the encoder
        :param goal_size: The size of the goal tensor that will condition the encoder
        :param hidden_size: The number of hidden units in the encoder
        :param num_layers: The total number of layers of the encoder (both regular and
        generated by HyperNetwork)
        :param num_conditional_layers: The number of layers generated with hypernetworks
        :param kernel_init: The Initialization to use for the weights of the regular
        (non-conditional) layers
        :param kernel_gain: The multiplier for the weights of the kernel of the regular
        (non-conditional) layers.
        :param num_hyper_layers: The number of hidden layers (each of width
        hidden_size) inside every HyperNetwork. Defaults to 2.
        """
        super().__init__()
        layers: List[torch.nn.Module] = []
        # The first layer sees the input concatenated with the goal (see forward).
        prev_size = input_size + goal_size
        # The conditional layers are the last num_conditional_layers ones, so every
        # layer index at or past this point is generated by a HyperNetwork.
        first_conditional = num_layers - num_conditional_layers
        for i in range(num_layers):
            if i >= first_conditional:
                layers.append(
                    HyperNetwork(
                        prev_size, hidden_size, goal_size, hidden_size, num_hyper_layers
                    )
                )
            else:
                layers.append(
                    linear_layer(
                        prev_size,
                        hidden_size,
                        kernel_init=kernel_init,
                        kernel_gain=kernel_gain,
                    )
                )
            layers.append(Swish())
            prev_size = hidden_size
        self.layers = torch.nn.ModuleList(layers)

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:  # type: ignore
        """
        Encode input_tensor conditioned on goal_tensor.
        The goal is concatenated to the input along the last dimension, then every
        layer is applied in order. HyperNetwork layers additionally receive the raw
        goal_tensor, from which they generate their weights.
        :param input_tensor: Tensor whose last dimension is input_size
        :param goal_tensor: Tensor whose last dimension is goal_size
        :return: The activation produced by the last layer
        """
        activation = torch.cat([input_tensor, goal_tensor], dim=-1)
        for layer in self.layers:
            if isinstance(layer, HyperNetwork):
                activation = layer(activation, goal_tensor)
            else:
                activation = layer(activation)
        return activation
|
|
|
|
|
|
class HyperNetwork(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hyper_input_size: int,
        layer_size: int,
        num_layers: int,
    ):
        """
        Hyper Network module. This module will use the hyper_input tensor to generate
        the weights of the main network. The main network is a single fully connected
        layer whose weight matrix is generated (one matrix per sample) by the
        hypernetwork and whose bias is a regular learned parameter.
        :param input_size: The size of the input of the main network
        :param output_size: The size of the output of the main network
        :param hyper_input_size: The size of the input of the hypernetwork that will
        generate the main network.
        :param layer_size: The number of hidden units in the layers of the hypernetwork
        :param num_layers: The number of hidden layers of the hypernetwork. May be 0,
        in which case the hyper input feeds the output projection directly.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        layer_in_size = hyper_input_size
        layers: List[torch.nn.Module] = []
        for _ in range(num_layers):
            layers.append(
                linear_layer(
                    layer_in_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
            layer_in_size = layer_size

        # Projects to a flattened (input_size x output_size) weight matrix. Its fan-in
        # is layer_in_size: layer_size when num_layers > 0, hyper_input_size when
        # num_layers == 0 (using layer_size there would mismatch the real input).
        flat_output = linear_layer(
            layer_in_size,
            input_size * output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )

        # Re-initializing the weights of the last layer of the hypernetwork with a
        # uniform distribution scaled by its fan-in and the main network's input size;
        # this overrides the kernel_init chosen above (the bias stays zero).
        bound = math.sqrt(1 / (layer_in_size * self.input_size))
        flat_output.weight.data.uniform_(-bound, bound)

        self.hypernet = torch.nn.Sequential(*layers, LayerNorm(), flat_output)

        # The hypernetwork will not generate the bias of the main network layer
        self.bias = torch.nn.Parameter(torch.zeros(output_size))

    def forward(
        self, input_activation: torch.Tensor, hyper_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply the generated fully connected layer to input_activation.
        Leading dimensions of both tensors are flattened, so the number of leading
        elements of input_activation must equal that of hyper_input (one generated
        weight matrix per row). The leading shape of input_activation is restored
        on the result.
        NOTE(review): leading dims beyond 2-D assume the hypernet layers act on the
        last dimension only — confirm for LayerNorm/Swish.
        :param input_activation: Tensor whose last dimension is input_size
        :param hyper_input: Tensor whose last dimension is hyper_input_size
        :return: Tensor of shape input_activation.shape[:-1] + (output_size,)
        """
        leading_shape = input_activation.shape[:-1]
        output_weights = self.hypernet(hyper_input).reshape(
            -1, self.input_size, self.output_size
        )
        # (N, 1, input_size) @ (N, input_size, output_size) -> (N, 1, output_size)
        flat_input = input_activation.reshape(-1, 1, self.input_size)
        flat_result = torch.bmm(flat_input, output_weights)
        result = flat_result.reshape(*leading_shape, self.output_size) + self.bias
        return result
|