|
|
|
|
|
|
from typing import Dict, List

import torch
from torch import nn

# linear_layer, Swish, and Initialization are ML-Agents' torch layer helpers.
from mlagents.trainers.torch.layers import Initialization, Swish, linear_layer

class HyperNetwork(nn.Module):
    def __init__(
        self, input_size, output_size, hyper_input_size, num_layers, layer_size
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        # First layer maps the hypernetwork input (the goal encoding) to layer_size.
        layers = [
            linear_layer(
                hyper_input_size,
                layer_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=1.0,
                bias_init=Initialization.Zero,
            ),
            Swish(),
        ]
        for _ in range(num_layers - 1):
            layers.append(
                linear_layer(
                    layer_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
        # The last layer emits every generated parameter flattened together:
        # an (input_size x output_size) weight matrix plus an output_size bias.
        flat_output = linear_layer(
            layer_size,
            input_size * output_size + output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )
        self.hypernet = nn.Sequential(*layers, flat_output)
    def forward(self, input_activation, hyper_input):
        # (b, input_size * output_size + output_size)
        flat_output_weights = self.hypernet(hyper_input)
        batch_size = input_activation.size(0)
        output_weights, output_bias = torch.split(
            flat_output_weights,
            [self.input_size * self.output_size, self.output_size],
            dim=-1,
        )
        output_weights = output_weights.view(
            batch_size, self.input_size, self.output_size
        )
        output_bias = output_bias.view(batch_size, self.output_size)
        # (b, 1, i) @ (b, i, o) -> (b, 1, o) -> (b, o), plus the generated bias
        output = (
            torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1)
            + output_bias
        )
        return output
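
# Design note: because the weights are generated per batch element (each element
# can carry a different goal), the matmul uses torch.bmm over batched weight
# matrices rather than a shared nn.Linear, which would tie every batch element
# to a single weight matrix.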
|
|
|
|
|
|
|
|
|
|
|
class ValueHeadsHyperNetwork(nn.Module):
    def __init__(
        self,
        num_layers,
        layer_size,
        goal_size,
        stream_names: List[str],
        input_size: int,
        output_size: int = 1,
    ):
        super().__init__()
        self.stream_names = stream_names
        self._num_goals = goal_size
        self.input_size = input_size
        self.output_size = output_size
        self.streams_size = len(stream_names)
        # A single hypernetwork generates the heads for all value streams at once.
        self.hypernetwork = HyperNetwork(
            input_size,
            self.output_size * self.streams_size,
            goal_size,
            num_layers,
            layer_size,
        )
    def forward(
        self, hidden: torch.Tensor, goal: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        # Condition the generated heads on the goal, one-hot encoded from indices.
        goal_onehot = torch.nn.functional.one_hot(
            goal[0].long(), self._num_goals
        ).float()
        # (b, output_size * streams_size)
        output = self.hypernetwork(hidden, goal_onehot)
        value_outputs: Dict[str, torch.Tensor] = {}
        # Slice the concatenated heads back into one tensor per reward stream.
        output_list = torch.split(output, self.output_size, dim=1)
        for stream_name, output_activation in zip(self.stream_names, output_list):
            value_outputs[stream_name] = output_activation
        return value_outputs
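

# Illustrative sketch (not part of the module): exercises the value heads with
# assumed sizes -- batch of 4, hidden size 128, two reward streams, 3 goals.
# The stream names are placeholders, not values prescribed by this module.
if __name__ == "__main__":
    value_heads = ValueHeadsHyperNetwork(
        num_layers=2,
        layer_size=64,
        goal_size=3,
        stream_names=["extrinsic", "curiosity"],
        input_size=128,
    )
    hidden = torch.ones(4, 128)
    # goal is indexed as goal[0] above, so it is assumed to be a list holding a
    # tensor with one goal index per batch element.
    goal = [torch.tensor([0, 1, 2, 0])]
    value_outputs = value_heads(hidden, goal)
    for name, value in value_outputs.items():
        print(name, value.shape)  # each stream yields a (4, 1) value tensor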