class AMRLMax(torch.nn.Module):
"""
Implements Aggregation for LSTM as described here:
https://www.microsoft.com/en-us/research/publication/amrl-aggregated-memory-for-reinforcement-learning/
def __init__(
self,
input_size: int,