
LSTM class

/develop/add-fire/memoryclass
Ervin Teng, 4 years ago
Commit d22d2e26
2 changed files with 36 additions and 53 deletions
  1. ml-agents/mlagents/trainers/torch/layers.py (85 changed lines)
  2. ml-agents/mlagents/trainers/torch/networks.py (4 changed lines)

ml-agents/mlagents/trainers/torch/layers.py (85 changed lines)


 import torch
+import abc
 from typing import Tuple
 from enum import Enum

     return lstm

-class AMRLMax(torch.nn.Module):
+class MemoryModule(torch.nn.Module):
+    @abc.abstractproperty
+    def memory_size(self) -> int:
+        """
+        Size of memory that is required at the start of a sequence.
+        """
+        pass
+
+    @abc.abstractmethod
+    def forward(
+        self, input_tensor: torch.Tensor, memories: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Pass a sequence to the memory module.
+        :input_tensor: Tensor of shape (batch_size, seq_length, size) that represents the input.
+        :memories: Tensor of initial memories.
+        :return: Tuple of output, final memories.
+        """
+        pass
+
+
+class LSTM(MemoryModule):
-    Implements Aggregation for LSTM as described here:
-    https://www.microsoft.com/en-us/research/publication/amrl-aggregated-memory-for-reinforcement-learning/
+    Memory module that implements LSTM.

-        hidden_size: int,
+        memory_size: int,
         batch_first: bool = True,
-        num_post_layers: int = 1,
+        self.hidden_size = memory_size // 2
-            hidden_size,
+            self.hidden_size,
             batch_first,
             True,
-        self.hidden_size = hidden_size
-        self.layers = []
-        for _ in range(num_post_layers):
-            self.layers.append(
-                linear_layer(
-                    hidden_size,
-                    hidden_size,
-                    kernel_init=Initialization.KaimingHeNormal,
-                    kernel_gain=1.0,
-                )
-            )
-            self.layers.append(Swish())
-        self.seq_layers = torch.nn.Sequential(*self.layers)

-        return self.hidden_size // 2 + 2 * self.hidden_size
+        return 2 * self.hidden_size

-        # memories is 1/2 * hidden_size (accumulant) + hidden_size/2 (h0) + hidden_size/2 (c0)
-        acc, h0, c0 = torch.split(
-            memories,
-            [self.hidden_size // 2, self.hidden_size, self.hidden_size],
-            dim=-1,
-        )
+        h0, c0 = torch.split(memories, self.hidden_size, dim=-1)
-        all_c = []
-        m = acc.permute([1, 0, 2])
-        lstm_out, (h0_out, c0_out) = self.lstm(input_tensor, hidden)
-        h_half, other_half = torch.split(lstm_out, self.hidden_size // 2, dim=-1)
-        for t in range(h_half.shape[1]):
-            h_half_subt = h_half[:, t : t + 1, :]
-            m = AMRLMax.PassthroughMax.apply(m, h_half_subt)
-            all_c.append(m)
-        concat_c = torch.cat(all_c, dim=1)
-        concat_out = torch.cat([concat_c, other_half], dim=-1)
-        full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size]))
-        full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size])
-        output_mem = torch.cat([m.permute([1, 0, 2]), h0_out, c0_out], dim=-1)
-        return concat_out, output_mem

-    class PassthroughMax(torch.autograd.Function):
-        @staticmethod
-        def forward(ctx, tensor1, tensor2):
-            return torch.max(tensor1, tensor2)

-        @staticmethod
-        def backward(ctx, grad_output):
-            return grad_output.clone(), grad_output.clone()

+        lstm_out, hidden_out = self.lstm(input_tensor, hidden)
+        output_mem = torch.cat(hidden_out, dim=-1)
+        return lstm_out, output_mem
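To make the new memory layout concrete, here is a minimal usage sketch (not part of the commit). It mirrors what the added LSTM.forward does, split the packed memories into h0 and c0, run the sequence, and concatenate the final states back together, but uses plain torch.nn.LSTM so it runs standalone; the batch/sequence sizes and the leading 1 (the num_layers dimension torch.nn.LSTM expects for its hidden state) are illustrative assumptions.

import torch

batch_size, seq_len, input_size, memory_size = 4, 16, 128, 256
hidden_size = memory_size // 2  # the new LSTM class halves memory_size the same way

lstm = torch.nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
x = torch.randn(batch_size, seq_len, input_size)

# h0 and c0 are packed into a single "memories" tensor along the last axis,
# so the trainer only has to carry one tensor of width memory_size per agent.
memories = torch.zeros(1, batch_size, memory_size)
h0, c0 = torch.split(memories, hidden_size, dim=-1)

out, (hn, cn) = lstm(x, (h0.contiguous(), c0.contiguous()))
next_memories = torch.cat([hn, cn], dim=-1)

print(out.shape)            # torch.Size([4, 16, 128])
print(next_memories.shape)  # torch.Size([1, 4, 256])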

ml-agents/mlagents/trainers/torch/networks.py (4 changed lines)


 from mlagents.trainers.settings import NetworkSettings
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.torch.decoders import ValueHeads
-from mlagents.trainers.torch.layers import AMRLMax
+from mlagents.trainers.torch.layers import LSTM

 ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
 EncoderFunction = Callable[

         )
         if self.use_lstm:
-            self.lstm = AMRLMax(self.h_size, self.m_size // 2, batch_first=True)
+            self.lstm = LSTM(self.h_size, self.m_size)
         else:
             self.lstm = None  # type: ignore
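A quick illustration of what the constructor change means for sizing (the numbers are hypothetical, and this assumes the class behaves as the layers.py diff above shows): AMRLMax was sized by its hidden width, so the caller passed self.m_size // 2, while the new LSTM takes the full configured memory size, halves it internally, and reports the full width back through its memory_size property, so it matches what the trainer allocates.

from mlagents.trainers.torch.layers import LSTM

h_size, m_size = 128, 256          # hypothetical encoder width and configured memory size
lstm = LSTM(h_size, m_size)        # was: AMRLMax(h_size, m_size // 2, batch_first=True)
assert lstm.memory_size == m_size  # 2 * (m_size // 2): h and c concatenated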
