|
|
|
|
|
|
masks = torch.tensor([False, False, False, False, False]) |
|
|
|
mean = ModelUtils.masked_mean(test_input, masks=masks) |
|
|
|
assert mean == 0.0 |
|
|
|
|
|
|
|
# Make sure it works with 2d arrays of shape (mask_length, N) |
|
|
|
test_input = torch.tensor([1, 2, 3, 4, 5]).repeat(2, 1).T |
|
|
|
masks = torch.tensor([False, False, True, True, True]) |
|
|
|
mean = ModelUtils.masked_mean(test_input, masks=masks) |
|
|
|
assert mean == 4.0 |