GitHub
4 年前
当前提交
ef3d6e0d
共有 2 个文件被更改,包括 185 次插入 和 0 次删除
-
52ml-agents/mlagents/trainers/tests/torch/test_conditioning.py
-
133ml-agents/mlagents/trainers/torch/conditioning.py
|
|||
import pytest |
|||
from mlagents.torch_utils import torch |
|||
import numpy as np |
|||
|
|||
from mlagents.trainers.torch.layers import linear_layer |
|||
from mlagents.trainers.torch.conditioning import ConditionalEncoder |
|||
|
|||
|
|||
def test_conditional_layer_initialization(): |
|||
b, input_size, goal_size, h, num_cond_layers, num_normal_layers = 7, 10, 8, 16, 2, 1 |
|||
conditional_enc = ConditionalEncoder( |
|||
input_size, goal_size, h, num_normal_layers + num_cond_layers, num_cond_layers |
|||
) |
|||
|
|||
input_tensor = torch.ones(b, input_size) |
|||
goal_tensor = torch.ones(b, goal_size) |
|||
|
|||
output = conditional_enc.forward(input_tensor, goal_tensor) |
|||
|
|||
assert output.shape == (b, h) |
|||
|
|||
|
|||
@pytest.mark.parametrize("num_cond_layers", [1, 2, 3]) |
|||
def test_predict_with_condition(num_cond_layers): |
|||
np.random.seed(1336) |
|||
torch.manual_seed(1336) |
|||
input_size, goal_size, h, num_normal_layers = 10, 1, 16, 1 |
|||
|
|||
conditional_enc = ConditionalEncoder( |
|||
input_size, goal_size, h, num_normal_layers + num_cond_layers, num_cond_layers |
|||
) |
|||
l_layer = linear_layer(h, 1) |
|||
|
|||
optimizer = torch.optim.Adam( |
|||
list(conditional_enc.parameters()) + list(l_layer.parameters()), lr=0.001 |
|||
) |
|||
batch_size = 200 |
|||
for _ in range(300): |
|||
input_tensor = torch.rand((batch_size, input_size)) |
|||
goal_tensor = (torch.rand((batch_size, goal_size)) > 0.5).float() |
|||
# If the goal is 1: do the sum of the inputs, else, return 0 |
|||
target = torch.sum(input_tensor, dim=1, keepdim=True) * goal_tensor |
|||
target.detach() |
|||
prediction = l_layer(conditional_enc(input_tensor, goal_tensor)) |
|||
error = torch.mean((prediction - target) ** 2, dim=1) |
|||
error = torch.mean(error) / 2 |
|||
|
|||
print(error.item()) |
|||
optimizer.zero_grad() |
|||
error.backward() |
|||
optimizer.step() |
|||
assert error.item() < 0.02 |
|
|||
from mlagents.torch_utils import torch |
|||
from typing import List |
|||
import math |
|||
|
|||
from mlagents.trainers.torch.layers import ( |
|||
linear_layer, |
|||
Swish, |
|||
Initialization, |
|||
LayerNorm, |
|||
) |
|||
|
|||
|
|||
class ConditionalEncoder(torch.nn.Module): |
|||
def __init__( |
|||
self, |
|||
input_size: int, |
|||
goal_size: int, |
|||
hidden_size: int, |
|||
num_layers: int, |
|||
num_conditional_layers: int, |
|||
kernel_init: Initialization = Initialization.KaimingHeNormal, |
|||
kernel_gain: float = 1.0, |
|||
): |
|||
""" |
|||
ConditionalEncoder module. A fully connected network of which some of the |
|||
weights are generated by a goal conditioning. Uses the HyperNetwork module to |
|||
generate the weights of the network. Only the weights of the last |
|||
"num_conditional_layers" layers will be generated by HyperNetworks, the others |
|||
will use regular parameters. |
|||
:param input_size: The size of the input of the encoder |
|||
:param goal_size: The size of the goal tensor that will condition the encoder |
|||
:param hidden_size: The number of hidden units in the encoder |
|||
:param num_layers: The total number of layers of the encoder (both regular and |
|||
generated by HyperNetwork) |
|||
:param num_conditional_layers: The number of layers generated with hypernetworks |
|||
:param kernel_init: The Initialization to use for the weights of the layer |
|||
:param kernel_gain: The multiplier for the weights of the kernel. |
|||
""" |
|||
super().__init__() |
|||
layers: List[torch.nn.Module] = [] |
|||
prev_size = input_size + goal_size |
|||
for i in range(num_layers): |
|||
if num_layers - i <= num_conditional_layers: |
|||
# This means layer i is a conditional layer since the conditional |
|||
# leyers are the last num_conditional_layers |
|||
layers.append( |
|||
HyperNetwork(prev_size, hidden_size, goal_size, hidden_size, 2) |
|||
) |
|||
else: |
|||
layers.append( |
|||
linear_layer( |
|||
prev_size, |
|||
hidden_size, |
|||
kernel_init=kernel_init, |
|||
kernel_gain=kernel_gain, |
|||
) |
|||
) |
|||
layers.append(Swish()) |
|||
prev_size = hidden_size |
|||
self.layers = torch.nn.ModuleList(layers) |
|||
|
|||
def forward( |
|||
self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor |
|||
) -> torch.Tensor: # type: ignore |
|||
activation = torch.cat([input_tensor, goal_tensor], dim=-1) |
|||
for layer in self.layers: |
|||
if isinstance(layer, HyperNetwork): |
|||
activation = layer(activation, goal_tensor) |
|||
else: |
|||
activation = layer(activation) |
|||
return activation |
|||
|
|||
|
|||
class HyperNetwork(torch.nn.Module): |
|||
def __init__( |
|||
self, input_size, output_size, hyper_input_size, layer_size, num_layers |
|||
): |
|||
""" |
|||
Hyper Network module. This module will use the hyper_input tensor to generate |
|||
the weights of the main network. The main network is a single fully connected |
|||
layer. |
|||
:param input_size: The size of the input of the main network |
|||
:param output_size: The size of the output of the main network |
|||
:param hyper_input_size: The size of the input of the hypernetwork that will |
|||
generate the main network. |
|||
:param layer_size: The number of hidden units in the layers of the hypernetwork |
|||
:param num_layers: The number of layers of the hypernetwork |
|||
""" |
|||
super().__init__() |
|||
self.input_size = input_size |
|||
self.output_size = output_size |
|||
|
|||
layer_in_size = hyper_input_size |
|||
layers = [] |
|||
for _ in range(num_layers): |
|||
layers.append( |
|||
linear_layer( |
|||
layer_in_size, |
|||
layer_size, |
|||
kernel_init=Initialization.KaimingHeNormal, |
|||
kernel_gain=1.0, |
|||
bias_init=Initialization.Zero, |
|||
) |
|||
) |
|||
layers.append(Swish()) |
|||
layer_in_size = layer_size |
|||
flat_output = linear_layer( |
|||
layer_size, |
|||
input_size * output_size, |
|||
kernel_init=Initialization.KaimingHeNormal, |
|||
kernel_gain=0.1, |
|||
bias_init=Initialization.Zero, |
|||
) |
|||
|
|||
# Re-initializing the weights of the last layer of the hypernetwork |
|||
bound = math.sqrt(1 / (layer_size * self.input_size)) |
|||
flat_output.weight.data.uniform_(-bound, bound) |
|||
|
|||
self.hypernet = torch.nn.Sequential(*layers, LayerNorm(), flat_output) |
|||
|
|||
# The hypernetwork will not generate the bias of the main network layer |
|||
self.bias = torch.nn.Parameter(torch.zeros(output_size)) |
|||
|
|||
def forward(self, input_activation, hyper_input): |
|||
output_weights = self.hypernet(hyper_input) |
|||
|
|||
output_weights = output_weights.view(-1, self.input_size, self.output_size) |
|||
|
|||
result = ( |
|||
torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1) |
|||
+ self.bias |
|||
) |
|||
return result |
撰写
预览
正在加载...
取消
保存
Reference in new issue