
Faster implementation

/develop/add-fire/memoryclass
Ervin Teng 4 years ago
Commit b44b5a24
1 changed file with 14 additions and 11 deletions
ml-agents/mlagents/trainers/torch/layers.py  +14 -11

         for _ in range(num_post_layers):
             self.layers.append(
                 linear_layer(
-                    input_size,
+                    hidden_size,
                     hidden_size,
                     kernel_init=Initialization.KaimingHeNormal,
                     kernel_gain=1.0,
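
The first hunk changes the post layers' in-features from input_size to hidden_size, which matches how the new forward() below feeds them: concat_out is reshaped to [-1, self.hidden_size] before going through self.seq_layers, so each post layer must accept rows that are hidden_size wide. A minimal sketch of that shape contract, using plain torch.nn.Linear as a stand-in for ml-agents' linear_layer helper and made-up sizes:

import torch
from torch import nn

# Stand-ins for the real hyperparameters; sizes here are illustrative only.
input_size, hidden_size, num_post_layers = 64, 128, 2

# Post layers must map hidden_size -> hidden_size, because forward() folds
# [batch, seq, hidden_size] into [batch * seq, hidden_size] before calling them.
layers = [nn.Linear(hidden_size, hidden_size) for _ in range(num_post_layers)]
seq_layers = nn.Sequential(*layers)

concat_out = torch.randn(4, 7, hidden_size)           # [batch, seq, hidden_size]
full_out = seq_layers(concat_out.reshape(-1, hidden_size))
full_out = full_out.reshape(-1, 7, hidden_size)       # back to [batch, seq, hidden_size]
print(full_out.shape)                                 # torch.Size([4, 7, 128])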

     def forward(self, input_tensor, h0_c0):
         hidden = h0_c0
-        all_out = []
-        for t in range(input_tensor.shape[1]):
-            out, hidden = self.lstm(input_tensor[:, t : t + 1, :], hidden)
-            h_half, other_half = torch.split(out, self.hidden_size // 2, dim=-1)
+        all_c = []
+        lstm_out, hidden = self.lstm(input_tensor, hidden)
+        h_half, other_half = torch.split(lstm_out, self.hidden_size // 2, dim=-1)
+        for t in range(h_half.shape[1]):
+            h_half_subt = h_half[:, t : t + 1, :]
             if t == 0:
-                m = h_half
+                m = h_half_subt
             else:
-                m = torch.max(m, h_half)
-            out = torch.cat([m, other_half])
-            all_out.append(out)
-        full_out = self.seq_layers(torch.cat(all_out, dim=1))
-        return full_out, hidden
+                m = torch.max(m, h_half_subt)
+            all_c.append(m)
+        concat_c = torch.cat(all_c, dim=1)
+        concat_out = torch.cat([concat_c, other_half], dim=-1)
+        full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size]))
+        full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size])
+        return concat_out, hidden
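
The speedup comes from the restructured forward(): the old version called self.lstm once per timestep from a Python loop, while the new one makes a single LSTM call over the whole sequence and only loops to build the running elementwise max. A small sketch (not part of the commit; the module and sizes below are made up) checking both that a single full-sequence call matches per-timestep calls, and that the running-max loop is just torch.cummax along the time dimension:

import torch

torch.manual_seed(0)
batch, seq_len, feat, hidden = 4, 7, 8, 16
lstm = torch.nn.LSTM(feat, hidden, batch_first=True)
x = torch.randn(batch, seq_len, feat)
h0_c0 = (torch.zeros(1, batch, hidden), torch.zeros(1, batch, hidden))

# Old style: one LSTM call per timestep, threading the hidden state through.
state = h0_c0
step_outs = []
for t in range(seq_len):
    out, state = lstm(x[:, t : t + 1, :], state)
    step_outs.append(out)
stepwise = torch.cat(step_outs, dim=1)

# New style: a single call over the full sequence gives the same outputs.
full, _ = lstm(x, h0_c0)
assert torch.allclose(stepwise, full, atol=1e-6)

# The running-max loop over h_half is a cumulative max along time.
h_half = full[..., : hidden // 2]
m, all_c = None, []
for t in range(seq_len):
    h_half_subt = h_half[:, t : t + 1, :]
    m = h_half_subt if m is None else torch.max(m, h_half_subt)
    all_c.append(m)
assert torch.allclose(torch.cat(all_c, dim=1), torch.cummax(h_half, dim=1).values)

If the remaining loop ever shows up in profiles, torch.cummax(h_half, dim=1).values could replace it outright. Note also that the new code computes full_out but returns concat_out; whether that is intentional is not clear from this diff alone.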