
Use built-in cumulative max

/develop/amrl
Ervin Teng, 4 years ago
Commit
9f96a495
1 changed file with 10 additions and 11 deletions
ml-agents/mlagents/trainers/torch/layers.py (21 changed lines)

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        memory_size: int,
        num_layers: int = 1,
        batch_first: bool = True,
        forget_bias: float = 1.0,
    ):
        super().__init__()
        self.hidden_size = memory_size // 2
        # Recurrent-layer arguments (enclosing call not shown in this excerpt):
            hidden_size,
            self.hidden_size,
            num_layers,
            batch_first,
            forget_bias,

        self.hidden_size = hidden_size
        # Linear post-layer arguments (enclosing call not shown in this excerpt):
            hidden_size,
            hidden_size,
            self.hidden_size,
            self.hidden_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=1.0,
        )

        dim=-1,
    )
    hidden = (h0, c0)
-   all_c = []
-   for t in range(h_half.shape[1]):
-       h_half_subt = h_half[:, t : t + 1, :]
-       m = AMRLMax.PassthroughMax.apply(m, h_half_subt)
-       all_c.append(m)
-   concat_c = torch.cat(all_c, dim=1)
+   # Append incoming m to LSTM hidden outs
+   cat_m = torch.cat([m, h_half], dim=1)
+   concat_c, _ = torch.cummax(cat_m, dim=1)
+   # Remove first element, which is just m
+   concat_c = concat_c[:, 1:, :]
    concat_out = torch.cat([concat_c, other_half], dim=-1)
    full_out = self.seq_layers(concat_out.reshape([-1, self.hidden_size]))
    full_out = full_out.reshape([-1, input_tensor.shape[1], self.hidden_size])
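
The replacement relies on torch.cummax, which computes a running maximum along a chosen dimension in a single vectorized call instead of a Python loop over timesteps. Below is a minimal, self-contained sketch of the equivalence; the tensor shapes are made up for illustration, and plain torch.max stands in for the custom AMRLMax.PassthroughMax (whose forward step is likewise an elementwise max):

    import torch

    batch, time, feat = 2, 5, 4
    m = torch.randn(batch, 1, feat)          # carried-over max state
    h_half = torch.randn(batch, time, feat)  # max-aggregated half of the LSTM output

    # Old approach: walk the time axis and keep an elementwise running max.
    running = m
    loop_out = []
    for t in range(h_half.shape[1]):
        running = torch.max(running, h_half[:, t : t + 1, :])
        loop_out.append(running)
    loop_out = torch.cat(loop_out, dim=1)

    # New approach: prepend m, take the built-in cumulative max along time,
    # then drop the first element, which is just m itself.
    cat_m = torch.cat([m, h_half], dim=1)
    cummax_out, _ = torch.cummax(cat_m, dim=1)
    cummax_out = cummax_out[:, 1:, :]

    assert torch.equal(loop_out, cummax_out)

The forward values match exactly; the cummax form avoids building a list of per-step tensors and scales better for long sequences.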
