# Running elementwise max over the sequence: h_half_subt is the slice of
# the to-be-aggregated half of the LSTM output at the current timestep
# (computed in the enclosing loop over the time dimension).
if m is None:
    m = h_half_subt
else:
    # PassthroughMax computes the elementwise max itself; calling
    # torch.max(m, h_half_subt) first would be redundant and would route
    # gradients only to the per-element winners.
    m = AMRLMax.PassthroughMax.apply(m, h_half_subt)
all_c.append(m)

# After the loop: stack the running maxima along the time dimension and
# rejoin them with the untouched half of the LSTM output.
concat_c = torch.cat(all_c, dim=1)
concat_out = torch.cat([concat_c, other_half], dim=-1)
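# A standalone sketch (not from the original source; names and shapes here
# are assumptions) of what the running-max aggregation above computes,
# using plain torch.max since gradient routing does not affect the shapes:
import torch

def _sketch_running_max() -> torch.Tensor:
    h_half = torch.randn(2, 5, 4)      # half of the LSTM output, to be aggregated
    other_half = torch.randn(2, 5, 4)  # the other half, passed through unchanged
    m, all_c = None, []
    for t in range(h_half.shape[1]):
        h_half_subt = h_half[:, t : t + 1, :]  # slice for timestep t
        m = h_half_subt if m is None else torch.max(m, h_half_subt)
        all_c.append(m)  # running max over timesteps 0..t
    concat_c = torch.cat(all_c, dim=1)                # shape [2, 5, 4]
    return torch.cat([concat_c, other_half], dim=-1)  # shape [2, 5, 8]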
class PassthroughMax(torch.autograd.Function):
    """Elementwise max that passes the incoming gradient through to both
    inputs, rather than only to the winning elements as torch.max does."""

    @staticmethod
    def forward(ctx, tensor1, tensor2):
        return torch.max(tensor1, tensor2)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the upstream gradient through to both operands unchanged.
        return grad_output.clone(), grad_output.clone()
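# A standalone sketch (not from the original source) of why PassthroughMax
# exists: torch.max routes gradient only to the elementwise winner, so an
# input that stops being the running max stops receiving learning signal,
# while PassthroughMax copies the gradient to both inputs. Assumes torch is
# imported and PassthroughMax is in scope; tie-free values chosen on purpose.
def _sketch_passthrough_gradient() -> None:
    a = torch.tensor([1.0, 3.0], requires_grad=True)
    b = torch.tensor([2.0, 2.0], requires_grad=True)

    torch.max(a, b).sum().backward()
    print(a.grad, b.grad)  # tensor([0., 1.]) tensor([1., 0.]) -- winner-take-all

    a.grad, b.grad = None, None
    PassthroughMax.apply(a, b).sum().backward()
    print(a.grad, b.grad)  # tensor([1., 1.]) tensor([1., 1.]) -- passed through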