
Adding Hypernetwork modules and unit tests (#5141)

Branch: /develop/lex-walker-model
GitHub · 3 years ago
Current commit: ef3d6e0d
2 files changed: 185 insertions, 0 deletions
1. ml-agents/mlagents/trainers/tests/torch/test_conditioning.py (+52)
2. ml-agents/mlagents/trainers/torch/conditioning.py (+133)

ml-agents/mlagents/trainers/tests/torch/test_conditioning.py (+52)


import pytest
from mlagents.torch_utils import torch
import numpy as np

from mlagents.trainers.torch.layers import linear_layer
from mlagents.trainers.torch.conditioning import ConditionalEncoder


def test_conditional_layer_initialization():
    b, input_size, goal_size, h, num_cond_layers, num_normal_layers = 7, 10, 8, 16, 2, 1
    conditional_enc = ConditionalEncoder(
        input_size, goal_size, h, num_normal_layers + num_cond_layers, num_cond_layers
    )
    input_tensor = torch.ones(b, input_size)
    goal_tensor = torch.ones(b, goal_size)
    output = conditional_enc.forward(input_tensor, goal_tensor)
    assert output.shape == (b, h)


@pytest.mark.parametrize("num_cond_layers", [1, 2, 3])
def test_predict_with_condition(num_cond_layers):
    np.random.seed(1336)
    torch.manual_seed(1336)
    input_size, goal_size, h, num_normal_layers = 10, 1, 16, 1
    conditional_enc = ConditionalEncoder(
        input_size, goal_size, h, num_normal_layers + num_cond_layers, num_cond_layers
    )
    l_layer = linear_layer(h, 1)
    optimizer = torch.optim.Adam(
        list(conditional_enc.parameters()) + list(l_layer.parameters()), lr=0.001
    )
    batch_size = 200
    for _ in range(300):
        input_tensor = torch.rand((batch_size, input_size))
        goal_tensor = (torch.rand((batch_size, goal_size)) > 0.5).float()
        # If the goal is 1, the target is the sum of the inputs; otherwise it is 0.
        # Note: Tensor.detach() is not in-place, so its result must be kept.
        target = (torch.sum(input_tensor, dim=1, keepdim=True) * goal_tensor).detach()
        prediction = l_layer(conditional_enc(input_tensor, goal_tensor))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
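
What the regression test above forces the pair to learn: with goal 1 the prediction tracks the row sum of the inputs, and with goal 0 it collapses to zero. A hypothetical follow-on check, reusing conditional_enc, l_layer, and input_size from the test (a sketch, not part of the commit):

    # Hypothetical sanity check appended after the training loop (illustration only).
    with torch.no_grad():
        x = torch.rand((4, input_size))
        on = l_layer(conditional_enc(x, torch.ones(4, 1)))    # expected: ~ row sums of x
        off = l_layer(conditional_enc(x, torch.zeros(4, 1)))  # expected: ~ 0
        print(on - torch.sum(x, dim=1, keepdim=True))  # residuals should be small
        print(off)                                     # values should be near zero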

ml-agents/mlagents/trainers/torch/conditioning.py (+133)


from mlagents.torch_utils import torch
from typing import List
import math

from mlagents.trainers.torch.layers import (
    linear_layer,
    Swish,
    Initialization,
    LayerNorm,
)


class ConditionalEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        goal_size: int,
        hidden_size: int,
        num_layers: int,
        num_conditional_layers: int,
        kernel_init: Initialization = Initialization.KaimingHeNormal,
        kernel_gain: float = 1.0,
    ):
        """
        ConditionalEncoder module. A fully connected network in which some of the
        weights are generated by a goal conditioning. Uses the HyperNetwork module
        to generate the weights of the network. Only the weights of the last
        "num_conditional_layers" layers are generated by HyperNetworks; the others
        use regular parameters.
        :param input_size: The size of the input of the encoder
        :param goal_size: The size of the goal tensor that will condition the encoder
        :param hidden_size: The number of hidden units in the encoder
        :param num_layers: The total number of layers of the encoder (both regular
            and generated by HyperNetwork)
        :param num_conditional_layers: The number of layers generated with hypernetworks
        :param kernel_init: The Initialization to use for the weights of the layer
        :param kernel_gain: The multiplier for the weights of the kernel
        """
        super().__init__()
        layers: List[torch.nn.Module] = []
        prev_size = input_size + goal_size
        for i in range(num_layers):
            if num_layers - i <= num_conditional_layers:
                # Layer i is a conditional layer, since the conditional
                # layers are the last num_conditional_layers of the encoder
                layers.append(
                    HyperNetwork(prev_size, hidden_size, goal_size, hidden_size, 2)
                )
            else:
                layers.append(
                    linear_layer(
                        prev_size,
                        hidden_size,
                        kernel_init=kernel_init,
                        kernel_gain=kernel_gain,
                    )
                )
            layers.append(Swish())
            prev_size = hidden_size
        self.layers = torch.nn.ModuleList(layers)

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:  # type: ignore
        activation = torch.cat([input_tensor, goal_tensor], dim=-1)
        for layer in self.layers:
            if isinstance(layer, HyperNetwork):
                activation = layer(activation, goal_tensor)
            else:
                activation = layer(activation)
        return activation
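
# Illustration (comment added for exposition, not in the original file): with
# num_layers=3 and num_conditional_layers=2, the encoder built above unrolls to
#   linear_layer(input_size + goal_size -> hidden_size), Swish,
#   HyperNetwork(hidden_size -> hidden_size, weights generated from the goal), Swish,
#   HyperNetwork(hidden_size -> hidden_size, weights generated from the goal), Swish
# The goal tensor is therefore used twice: concatenated to the input of the first
# layer, and fed to the hypernetworks that generate the last two layers' weights.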


class HyperNetwork(torch.nn.Module):
    def __init__(
        self, input_size, output_size, hyper_input_size, layer_size, num_layers
    ):
        """
        HyperNetwork module. This module uses the hyper_input tensor to generate
        the weights of the main network. The main network is a single fully
        connected layer.
        :param input_size: The size of the input of the main network
        :param output_size: The size of the output of the main network
        :param hyper_input_size: The size of the input of the hypernetwork that
            will generate the main network
        :param layer_size: The number of hidden units in the layers of the hypernetwork
        :param num_layers: The number of layers of the hypernetwork
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        layer_in_size = hyper_input_size
        layers = []
        for _ in range(num_layers):
            layers.append(
                linear_layer(
                    layer_in_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
            layer_in_size = layer_size
        flat_output = linear_layer(
            layer_size,
            input_size * output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )
        # Re-initialize the weights of the last layer of the hypernetwork
        bound = math.sqrt(1 / (layer_size * self.input_size))
        flat_output.weight.data.uniform_(-bound, bound)
        self.hypernet = torch.nn.Sequential(*layers, LayerNorm(), flat_output)
        # The hypernetwork does not generate the bias of the main network layer
        self.bias = torch.nn.Parameter(torch.zeros(output_size))

    def forward(self, input_activation, hyper_input):
        # Generate one (input_size x output_size) weight matrix per batch element,
        # then apply it to the input activation with a batched matrix multiply.
        output_weights = self.hypernet(hyper_input)
        output_weights = output_weights.view(-1, self.input_size, self.output_size)
        result = (
            torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1)
            + self.bias
        )
        return result
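
For readers outside the ml-agents codebase, the weight-generation mechanic of HyperNetwork can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the module's implementation: standard torch.nn.Linear and ReLU stand in for ml-agents' linear_layer and Swish, the custom initialization and LayerNorm are omitted, and TinyHyperLayer is a hypothetical name.

    import torch

    class TinyHyperLayer(torch.nn.Module):
        """Sketch: an MLP maps a goal g to the flattened weight matrix W(g)
        of one main-network layer, so the output is y = x @ W(g) + b."""

        def __init__(self, input_size, output_size, goal_size, hidden=16):
            super().__init__()
            self.input_size, self.output_size = input_size, output_size
            self.weight_gen = torch.nn.Sequential(
                torch.nn.Linear(goal_size, hidden),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden, input_size * output_size),
            )
            # As in the module above, the bias is a regular parameter,
            # not generated by the hypernetwork.
            self.bias = torch.nn.Parameter(torch.zeros(output_size))

        def forward(self, x, goal):
            w = self.weight_gen(goal).view(-1, self.input_size, self.output_size)
            # Per-sample matmul: (B, 1, in) @ (B, in, out) -> (B, 1, out) -> (B, out)
            return torch.bmm(x.unsqueeze(1), w).squeeze(1) + self.bias

    layer = TinyHyperLayer(input_size=10, output_size=4, goal_size=2)
    y = layer(torch.rand(7, 10), torch.rand(7, 2))
    assert y.shape == (7, 4)

The design choice mirrored here is that only the weights depend on the goal; keeping the bias as an ordinary parameter caps the generated parameter count at input_size * output_size per conditional layer.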