|
|
|
|
|
|
assert entropies.shape == (1, len(dist_list)) |
|
|
|
# Make sure the first action has a higher probability than the others. |
|
|
|
assert log_probs.flatten()[0] > log_probs.flatten()[1] |
|
|
|
|
|
|
|
|
|
|
|
def test_masked_mean():
    """Check ModelUtils.masked_mean with full, partial, and all-False masks.

    The same five-element input is averaged under three boolean masks; the
    expected values are the means of the selected elements, with 0.0
    expected when no element is selected.
    """
    test_input = torch.tensor([1, 2, 3, 4, 5])

    # All elements selected: the mean of 1..5 is 3.
    masks = torch.ones_like(test_input).bool()
    mean = ModelUtils.masked_mean(test_input, masks=masks)
    assert mean == 3.0

    # Only the last three elements (3, 4, 5) selected: the mean is 4.
    masks = torch.tensor([False, False, True, True, True])
    mean = ModelUtils.masked_mean(test_input, masks=masks)
    assert mean == 4.0

    # Make sure it works if all masks are off
    masks = torch.tensor([False, False, False, False, False])
    mean = ModelUtils.masked_mean(test_input, masks=masks)
    assert mean == 0.0