
Add AMRL layer

Branch: /develop/amrl
Ervin Teng, 4 years ago
Commit 14a7e29b
1 changed file with 78 additions and 0 deletions
ml-agents/mlagents/trainers/torch/layers.py  +78 −0


        lstm_out, hidden_out = self.lstm(input_tensor, hidden)
        output_mem = torch.cat(hidden_out, dim=-1)
        return lstm_out, output_mem

class AMRLMax(MemoryModule):
    """
    Implements Aggregation for LSTM as described here:
    https://www.microsoft.com/en-us/research/publication/amrl-aggregated-memory-for-reinforcement-learning/
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        batch_first: bool = True,
        forget_bias: float = 1.0,
        kernel_init: Initialization = Initialization.XavierGlorotUniform,
        bias_init: Initialization = Initialization.Zero,
        num_post_layers: int = 1,
    ):
        super().__init__()
        self.lstm = lstm_layer(
            input_size,
            hidden_size,
            num_layers,
            batch_first,
            forget_bias,
            kernel_init,
            bias_init,
        )
        self.hidden_size = hidden_size
        self.layers = []
        for _ in range(num_post_layers):
            self.layers.append(
                linear_layer(
                    hidden_size,
                    hidden_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                )
            )
            self.layers.append(Swish())
        self.seq_layers = torch.nn.Sequential(*self.layers)

    @property
    def memory_size(self) -> int:
        return self.hidden_size // 2 + 2 * self.hidden_size

    def forward(self, input_tensor, memories):
        # memories is hidden_size/2 (accumulant) + hidden_size (h0) + hidden_size (c0)
        acc, h0, c0 = torch.split(
            memories,
            [self.hidden_size // 2, self.hidden_size, self.hidden_size],
            dim=-1,
        )
        # The LSTM expects contiguous hidden states.
        hidden = (h0.contiguous(), c0.contiguous())
        all_c = []
        m = acc.permute([1, 0, 2])
        lstm_out, (h0_out, c0_out) = self.lstm(input_tensor, hidden)
        # Half of the LSTM output is max-aggregated over time; the other half
        # passes through unchanged.
        h_half, other_half = torch.split(lstm_out, self.hidden_size // 2, dim=-1)
        for t in range(h_half.shape[1]):
            h_half_subt = h_half[:, t : t + 1, :]
            # Running elementwise max with a straight-through gradient.
            m = AMRLMax.PassthroughMax.apply(m, h_half_subt)
            all_c.append(m)
        concat_c = torch.cat(all_c, dim=1)
        concat_out = torch.cat([concat_c, other_half], dim=-1)
        # Apply the post-LSTM layers to the aggregated output.
        full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size]))
        full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size])
        output_mem = torch.cat([m.permute([1, 0, 2]), h0_out, c0_out], dim=-1)
        return full_out, output_mem

    class PassthroughMax(torch.autograd.Function):
        # Elementwise max whose backward pass sends the full gradient to both inputs.
        @staticmethod
        def forward(ctx, tensor1, tensor2):
            return torch.max(tensor1, tensor2)

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output.clone(), grad_output.clone()
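
For context, a minimal usage sketch of the new layer (not part of this commit): it assumes this branch's layers.py is importable, and the batch size, sequence length, and zero-initialized memory tensor are made-up values. The shapes follow from memory_size and the batch-first LSTM above.

    import torch

    from mlagents.trainers.torch.layers import AMRLMax

    batch_size, seq_len, input_size, hidden_size = 4, 16, 32, 64
    amrl = AMRLMax(input_size=input_size, hidden_size=hidden_size)

    # memory_size = hidden_size // 2 + 2 * hidden_size = 32 + 128 = 160
    # (max accumulant, then the LSTM's h and c states).
    obs = torch.randn(batch_size, seq_len, input_size)   # batch-first sequence
    mem = torch.zeros(1, batch_size, amrl.memory_size)   # initial memories

    out, next_mem = amrl(obs, mem)
    print(out.shape)       # torch.Size([4, 16, 64])
    print(next_mem.shape)  # torch.Size([1, 4, 160])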
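
And a small sketch (also not part of the commit; the tensors are invented for illustration) of what PassthroughMax changes relative to a plain torch.max: the standard op only back-propagates to the element-wise winner, while the custom backward clones the upstream gradient to both inputs, so non-winning time steps still receive a learning signal.

    import torch

    from mlagents.trainers.torch.layers import AMRLMax

    a = torch.tensor([1.0, 5.0], requires_grad=True)
    b = torch.tensor([3.0, 2.0], requires_grad=True)

    # Plain elementwise max: gradient reaches only the larger input per element.
    torch.max(a, b).sum().backward()
    print(a.grad, b.grad)  # tensor([0., 1.]) tensor([1., 0.])

    a.grad = None
    b.grad = None

    # PassthroughMax: backward() copies the upstream gradient to both inputs.
    AMRLMax.PassthroughMax.apply(a, b).sum().backward()
    print(a.grad, b.grad)  # tensor([1., 1.]) tensor([1., 1.])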