浏览代码

Passthrough max

/develop/add-fire/memoryclass
Ervin Teng 4 年前
当前提交
46f3a9b9
共有 1 个文件被更改,包括 10 次插入1 次删除
  1. 11
      ml-agents/mlagents/trainers/torch/layers.py

11
ml-agents/mlagents/trainers/torch/layers.py


if m is None:
m = h_half_subt
else:
m = torch.max(m, h_half_subt)
m = AMRLMax.PassthroughMax.apply(m, h_half_subt)
all_c.append(m)
concat_c = torch.cat(all_c, dim=1)
concat_out = torch.cat([concat_c, other_half], dim=-1)

class PassthroughMax(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor1, tensor2):
return torch.max(tensor1, tensor2)
@staticmethod
def backward(ctx, grad_output):
return grad_output.clone(), grad_output.clone()
正在加载...
取消
保存