
Cleanup, add test

Branch: /develop/add-fire/memoryclass
Ervin Teng, 4 years ago
Commit: d63aacd0

2 files changed, 24 insertions(+), 1 deletion(-)
  1. ml-agents/mlagents/trainers/tests/torch/test_layers.py (19 changes)
  2. ml-agents/mlagents/trainers/torch/layers.py (6 changes)

ml-agents/mlagents/trainers/tests/torch/test_layers.py (19 changes)


@@ ... @@
     linear_layer,
     lstm_layer,
     Initialization,
+    LSTM,
 )

@@ ... @@
             assert torch.all(
                 torch.eq(param.data[4:8], torch.ones_like(param.data[4:8]))
             )
+
+
+def test_lstm_class():
+    torch.manual_seed(0)
+    input_size = 12
+    memory_size = 64
+    batch_size = 8
+    seq_len = 16
+    lstm = LSTM(input_size, memory_size)
+
+    assert lstm.memory_size == memory_size
+
+    sample_input = torch.ones((batch_size, seq_len, input_size))
+    sample_memories = torch.ones((1, batch_size, memory_size))
+    out, mem = lstm(sample_input, sample_memories)
+    # Hidden size should be half of memory_size
+    assert out.shape == (batch_size, seq_len, memory_size // 2)
+    assert mem.shape == (1, batch_size, memory_size)
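Taken together, the new assertions pin down the class's shape contract: each timestep's output has memory_size // 2 features (the hidden size of the underlying recurrent layer), while the returned memory tensor keeps the full memory_size width so it can be fed straight back in on the next call. Assuming the repository's usual pytest workflow, the test can be run in isolation with pytest ml-agents/mlagents/trainers/tests/torch/test_layers.py::test_lstm_class.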

ml-agents/mlagents/trainers/torch/layers.py (6 changes)


@@ ... @@
         bias_init: Initialization = Initialization.Zero,
     ):
         super().__init__()
+        # We set hidden size to half of memory_size since the initial memory
+        # will be divided between the hidden state and initial cell state.
         self.hidden_size = memory_size // 2
         self.lstm = lstm_layer(
             input_size,

@@ ... @@
     @property
     def memory_size(self) -> int:
         return 2 * self.hidden_size

-    def forward(self, input_tensor, memories):
+    def forward(
+        self, input_tensor: torch.Tensor, memories: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         h0, c0 = torch.split(memories, self.hidden_size, dim=-1)
         hidden = (h0, c0)
         lstm_out, hidden_out = self.lstm(input_tensor, hidden)
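
The split/concat convention that forward relies on is easy to check in isolation. The sketch below mirrors the memory layout with a plain torch.nn.LSTM (batch_first=True) rather than the repo's lstm_layer helper, which additionally handles weight and forget-bias initialization; the shapes and the memory_size // 2 relationship are the same ones the test above asserts. The .contiguous() calls are defensive, since torch.split returns non-contiguous views and some PyTorch backends require contiguous hidden states.

    import torch

    batch_size, seq_len = 8, 16
    input_size, memory_size = 12, 64
    hidden_size = memory_size // 2  # half the flat memory per state

    lstm = torch.nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)

    sample_input = torch.ones((batch_size, seq_len, input_size))
    memories = torch.zeros((1, batch_size, memory_size))

    # Split the flat memory into initial hidden and cell states along the last axis.
    h0, c0 = torch.split(memories, hidden_size, dim=-1)
    out, (hn, cn) = lstm(sample_input, (h0.contiguous(), c0.contiguous()))

    # Concatenate the final states back into one flat memory tensor.
    new_memories = torch.cat((hn, cn), dim=-1)

    assert out.shape == (batch_size, seq_len, hidden_size)
    assert new_memories.shape == (1, batch_size, memory_size)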