浏览代码

Add hypernetwork

/hypernetwork
Arthur Juliani 4 年前
当前提交
da0c8b9d
共有 1 个文件被更改,包括 50 次插入0 次删除
  1. 50
      ml-agents/mlagents/trainers/torch/layers.py

50
ml-agents/mlagents/trainers/torch/layers.py


lstm_out, hidden_out = self.lstm(input_tensor, hidden)
output_mem = torch.cat(hidden_out, dim=-1)
return lstm_out, output_mem
class HyperNetwork(torch.nn.Module):
def __init__(
self, input_size, output_size, hyper_input_size, num_layers, layer_size
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
layer_in_size = hyper_input_size
layers = []
for _ in range(num_layers):
layers.append(
linear_layer(
layer_in_size,
layer_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=1.0,
bias_init=Initialization.Zero,
)
)
layers.append(Swish())
layer_in_size = layer_size
flat_output = linear_layer(
layer_size,
input_size * output_size + output_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=0.1,
bias_init=Initialization.Zero,
)
self.hypernet = torch.nn.Sequential(*layers, flat_output)
def forward(self, input_activation, hyper_input):
flat_output_weights = self.hypernet(hyper_input)
batch_size = input_activation.size(0)
output_weights, output_bias = torch.split(
flat_output_weights, self.input_size * self.output_size, dim=-1
)
output_weights = output_weights.view(
batch_size, self.input_size, self.output_size
)
output_bias = output_bias.view(batch_size, self.output_size)
output = (
torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1)
+ output_bias
)
return output
正在加载...
取消
保存