from mlagents.torch_utils import torch
from typing import List
import math

from mlagents.trainers.torch.layers import (
    linear_layer,
    Swish,
    Initialization,
    LayerNorm,
)


class ConditionalEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        goal_size: int,
        hidden_size: int,
        num_layers: int,
        num_conditional_layers: int,
        kernel_init: Initialization = Initialization.KaimingHeNormal,
        kernel_gain: float = 1.0,
    ):
        """
        ConditionalEncoder module. A fully connected network in which some of the
        weights are generated by a goal conditioning. Uses the HyperNetwork module
        to generate the weights of the network. Only the weights of the last
        "num_conditional_layers" layers will be generated by HyperNetworks; the
        others use regular parameters.
        :param input_size: The size of the input of the encoder
        :param goal_size: The size of the goal tensor that will condition the encoder
        :param hidden_size: The number of hidden units in the encoder
        :param num_layers: The total number of layers of the encoder (both regular
        and generated by HyperNetwork)
        :param num_conditional_layers: The number of layers generated with
        hypernetworks
        :param kernel_init: The Initialization to use for the weights of the layer
        :param kernel_gain: The multiplier for the weights of the kernel.
        """
        super().__init__()
        layers: List[torch.nn.Module] = []
        prev_size = input_size + goal_size
        for i in range(num_layers):
            if num_layers - i <= num_conditional_layers:
                # Layer i is a conditional layer since the conditional layers
                # are the last num_conditional_layers of the network
                layers.append(
                    HyperNetwork(prev_size, hidden_size, goal_size, hidden_size, 2)
                )
            else:
                layers.append(
                    linear_layer(
                        prev_size,
                        hidden_size,
                        kernel_init=kernel_init,
                        kernel_gain=kernel_gain,
                    )
                )
            layers.append(Swish())
            prev_size = hidden_size
        self.layers = torch.nn.ModuleList(layers)

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:  # type: ignore
        # The goal is concatenated to the input for the regular layers and also
        # passed separately to the HyperNetwork layers as the conditioning input.
        activation = torch.cat([input_tensor, goal_tensor], dim=-1)
        for layer in self.layers:
            if isinstance(layer, HyperNetwork):
                activation = layer(activation, goal_tensor)
            else:
                activation = layer(activation)
        return activation
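
# Illustrative usage sketch, not part of the original module: a quick shape
# check for ConditionalEncoder. All sizes are arbitrary assumptions; the call
# is wrapped in a function so that HyperNetwork (defined below) is only
# resolved when the demo actually runs.
def _demo_conditional_encoder() -> None:
    encoder = ConditionalEncoder(
        input_size=16,
        goal_size=4,
        hidden_size=32,
        num_layers=3,
        num_conditional_layers=1,
    )
    obs = torch.randn(8, 16)
    goal = torch.randn(8, 4)
    # The goal is both concatenated to the input and used to generate the
    # weights of the last num_conditional_layers layers.
    assert encoder(obs, goal).shape == (8, 32)
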
class HyperNetwork(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hyper_input_size: int,
        layer_size: int,
        num_layers: int,
    ):
        """
        HyperNetwork module. This module uses the hyper_input tensor to generate
        the weights of the main network. The main network is a single fully
        connected layer.
        :param input_size: The size of the input of the main network
        :param output_size: The size of the output of the main network
        :param hyper_input_size: The size of the input of the hypernetwork that
        will generate the main network.
        :param layer_size: The number of hidden units in the layers of the
        hypernetwork
        :param num_layers: The number of layers of the hypernetwork
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        layer_in_size = hyper_input_size
        layers: List[torch.nn.Module] = []
        for _ in range(num_layers):
            layers.append(
                linear_layer(
                    layer_in_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
            layer_in_size = layer_size
        flat_output = linear_layer(
            layer_size,
            input_size * output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )

        # Re-initialize the weights of the last layer of the hypernetwork
        bound = math.sqrt(1 / (layer_size * self.input_size))
        flat_output.weight.data.uniform_(-bound, bound)

        self.hypernet = torch.nn.Sequential(*layers, LayerNorm(), flat_output)

        # The hypernetwork will not generate the bias of the main network layer
        self.bias = torch.nn.Parameter(torch.zeros(output_size))

    def forward(
        self, input_activation: torch.Tensor, hyper_input: torch.Tensor
    ) -> torch.Tensor:
        # Generate one flattened weight matrix per sample in the batch and
        # reshape it to (batch, input_size, output_size)
        output_weights = self.hypernet(hyper_input)
        output_weights = output_weights.view(-1, self.input_size, self.output_size)
        # Apply the generated weights with a batched matrix multiplication,
        # then add the (regular, learned) bias
        result = (
            torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1)
            + self.bias
        )
        return result
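

# Illustrative usage sketch, not part of the original module: HyperNetwork maps
# a conditioning tensor to a per-sample (input_size x output_size) weight matrix
# and applies it to the input activations. All sizes below are arbitrary
# assumptions.
if __name__ == "__main__":
    hyper = HyperNetwork(
        input_size=16, output_size=32, hyper_input_size=4, layer_size=64, num_layers=2
    )
    x = torch.randn(8, 16)  # batch of main-network inputs
    h = torch.randn(8, 4)  # batch of conditioning ("hyper") inputs
    assert hyper(x, h).shape == (8, 32)
    _demo_conditional_encoder()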