import abc
from enum import Enum
from typing import Tuple

import torch
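

# NOTE: The classes below reference Swish, Initialization, and linear_layer,
# whose definitions did not survive in this file. The following is a minimal
# sketch of those helpers based on how they are called below; the exact
# initialization scheme chosen here is an assumption, not the canonical version.
class Swish(torch.nn.Module):
    """Swish activation: x * sigmoid(x)."""

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        return data * torch.sigmoid(data)


class Initialization(Enum):
    Zero = 0
    XavierGlorotNormal = 1
    XavierGlorotUniform = 2
    KaimingHeNormal = 3
    KaimingHeUniform = 4


_init_methods = {
    Initialization.Zero: torch.nn.init.zeros_,
    Initialization.XavierGlorotNormal: torch.nn.init.xavier_normal_,
    Initialization.XavierGlorotUniform: torch.nn.init.xavier_uniform_,
    Initialization.KaimingHeNormal: torch.nn.init.kaiming_normal_,
    Initialization.KaimingHeUniform: torch.nn.init.kaiming_uniform_,
}


def linear_layer(
    input_size: int,
    output_size: int,
    kernel_init: Initialization = Initialization.XavierGlorotUniform,
    kernel_gain: float = 1.0,
    bias_init: Initialization = Initialization.Zero,
) -> torch.nn.Module:
    """
    Creates a torch.nn.Linear layer and initializes its weights and biases.
    :param kernel_gain: Scale applied to the weights after initialization.
    """
    layer = torch.nn.Linear(input_size, output_size)
    _init_methods[kernel_init](layer.weight.data)
    layer.weight.data *= kernel_gain
    _init_methods[bias_init](layer.bias.data)
    return layer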
|
|
|
def lstm_layer(
    input_size: int,
    hidden_size: int,
    num_layers: int = 1,
    batch_first: bool = True,
    forget_bias: float = 1.0,
) -> torch.nn.Module:
    """
    Creates a torch.nn.LSTM and initializes its biases so that the forget gate
    starts with a positive bias, as is done by default in TensorFlow.
    NOTE: only the final "return lstm" of this helper survived; the signature
    and initialization below are a reconstruction and should be treated as
    assumptions.
    """
    lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=batch_first)
    for name, param in lstm.named_parameters():
        if "bias" in name:
            torch.nn.init.zeros_(param.data)
            # PyTorch packs the gate biases as [input, forget, cell, output],
            # so the forget-gate slice is [hidden_size : 2 * hidden_size].
            param.data[hidden_size : 2 * hidden_size].add_(forget_bias)
    return lstm

|

class MemoryModule(torch.nn.Module):
    @abc.abstractproperty
    def memory_size(self) -> int:
        """
        Size of memory that is required at the start of a sequence.
        """
        pass

    @abc.abstractmethod
    def forward(
        self, input_tensor: torch.Tensor, memories: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Pass a sequence to the memory module.
        :input_tensor: Tensor of shape (batch_size, seq_length, size) that represents the input.
        :memories: Tensor of initial memories.
        :return: Tuple of output, final memories.
        """
        pass


class AMRLMax(torch.nn.Module):
    """
    Implements Aggregation for LSTM as described here:
    https://www.microsoft.com/en-us/research/publication/amrl-aggregated-memory-for-reinforcement-learning/
    Half of the LSTM output is fed through a running elementwise max with a
    straight-through gradient (PassthroughMax below); the other half is passed
    through unchanged.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        batch_first: bool = True,
        num_post_layers: int = 1,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = lstm_layer(input_size, hidden_size, batch_first=batch_first)
        # Post-processing layers applied to the aggregated output.
        self.layers = []
        for _ in range(num_post_layers):
            self.layers.append(
                linear_layer(
                    hidden_size,
                    hidden_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                )
            )
            self.layers.append(Swish())
        self.seq_layers = torch.nn.Sequential(*self.layers)

    @property
    def memory_size(self) -> int:
        # Max accumulant (hidden_size // 2) plus the LSTM's h0 and c0.
        return self.hidden_size // 2 + 2 * self.hidden_size

    def forward(
        self, input_tensor: torch.Tensor, memories: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # memories is hidden_size / 2 (accumulant) + hidden_size (h0) + hidden_size (c0),
        # shaped like an LSTM hidden state: (1, batch_size, memory_size).
        acc, h0, c0 = torch.split(
            memories,
            [self.hidden_size // 2, self.hidden_size, self.hidden_size],
            dim=-1,
        )
        hidden = (h0.contiguous(), c0.contiguous())
        all_c = []
        # Permute the accumulant to (batch_size, 1, hidden_size // 2) so it
        # broadcasts against per-timestep slices of the LSTM output.
        m = acc.permute([1, 0, 2])
        lstm_out, (h0_out, c0_out) = self.lstm(input_tensor, hidden)
        # First half of the LSTM output feeds the running max; the second half
        # is passed through unchanged.
        h_half, other_half = torch.split(lstm_out, self.hidden_size // 2, dim=-1)
        for t in range(h_half.shape[1]):
            h_half_subt = h_half[:, t : t + 1, :]
            m = AMRLMax.PassthroughMax.apply(m, h_half_subt)
            all_c.append(m)
        concat_c = torch.cat(all_c, dim=1)
        concat_out = torch.cat([concat_c, other_half], dim=-1)
        full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size]))
        full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size])
        output_mem = torch.cat([m.permute([1, 0, 2]), h0_out, c0_out], dim=-1)
        return full_out, output_mem

    class PassthroughMax(torch.autograd.Function):
        """
        Elementwise max whose backward pass ignores which element was selected
        and passes the upstream gradient to both inputs unchanged
        (a straight-through estimator).
        """

        @staticmethod
        def forward(ctx, tensor1, tensor2):
            return torch.max(tensor1, tensor2)

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output.clone(), grad_output.clone()


class LSTM(MemoryModule):
    """
    Memory module that implements LSTM.
    """

    def __init__(self, input_size: int, memory_size: int):
        super().__init__()
        # The memory is split evenly between the hidden state (h0) and the cell
        # state (c0), so the LSTM hidden size is half of memory_size.
        self.hidden_size = memory_size // 2
        self.lstm = lstm_layer(input_size, self.hidden_size, batch_first=True)

    @property
    def memory_size(self) -> int:
        return 2 * self.hidden_size

    def forward(
        self, input_tensor: torch.Tensor, memories: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # memories holds h0 and c0 (hidden_size each), shaped like torch.nn.LSTM
        # hidden states: (1, batch_size, memory_size).
        h0, c0 = torch.split(memories, self.hidden_size, dim=-1)
        hidden = (h0.contiguous(), c0.contiguous())
        lstm_out, hidden_out = self.lstm(input_tensor, hidden)
        output_mem = torch.cat(hidden_out, dim=-1)
        return lstm_out, output_mem
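

# The block below is not part of the module proper: it is a minimal usage
# sketch showing how the memory modules are driven with zero-initialized
# memories. The shapes assumed here (inputs of (batch, seq, feature), memories
# of (1, batch, memory_size)) follow torch.nn.LSTM conventions and are an
# assumption, as are the example sizes.
if __name__ == "__main__":
    batch_size, seq_len, feature_size = 4, 8, 6

    lstm_module = LSTM(input_size=feature_size, memory_size=16)
    amrl_module = AMRLMax(input_size=feature_size, hidden_size=16)

    sequence = torch.randn(batch_size, seq_len, feature_size)

    # Each sequence starts with all-zero memories of the size the module reports.
    lstm_mem = torch.zeros(1, batch_size, lstm_module.memory_size)
    amrl_mem = torch.zeros(1, batch_size, amrl_module.memory_size)

    lstm_out, lstm_mem_out = lstm_module(sequence, lstm_mem)
    amrl_out, amrl_mem_out = amrl_module(sequence, amrl_mem)

    print(lstm_out.shape, lstm_mem_out.shape)  # (4, 8, 8) and (1, 4, 16)
    print(amrl_out.shape, amrl_mem_out.shape)  # (4, 8, 16) and (1, 4, 40)

    # PassthroughMax takes the elementwise max in the forward pass but routes
    # the full upstream gradient to both inputs in the backward pass.
    a = torch.tensor([1.0, 3.0], requires_grad=True)
    b = torch.tensor([2.0, 0.0], requires_grad=True)
    AMRLMax.PassthroughMax.apply(a, b).sum().backward()
    print(a.grad, b.grad)  # both tensor([1., 1.])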