|
|
|
|
|
|
log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy( |
|
|
|
action_list, dist_list |
|
|
|
) |
|
|
|
assert log_probs.shape == (1, 2, 2) |
|
|
|
for lp in log_probs: |
|
|
|
assert lp.shape == (1, 2) |
|
|
|
assert all_probs is None |
|
|
|
assert all_probs == [] |
|
|
|
for log_prob in log_probs.flatten(): |
|
|
|
for log_prob in log_probs: |
|
|
|
assert log_prob == pytest.approx(-0.919, abs=0.01) |
|
|
|
for lp in log_prob.flatten(): |
|
|
|
assert lp == pytest.approx(-0.919, abs=0.01) |
|
|
|
|
|
|
|
for ent in entropies.flatten(): |
|
|
|
# entropy of standard normal at 0 |
|
|
|
|
|
|
log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy( |
|
|
|
action_list, dist_list |
|
|
|
) |
|
|
|
assert all_probs.shape == (1, len(dist_list * act_size)) |
|
|
|
for all_prob in all_probs: |
|
|
|
assert all_prob.shape == (1, act_size) |
|
|
|
assert log_probs.flatten()[0] > log_probs.flatten()[1] |
|
|
|
assert log_probs[0] > log_probs[1] |
|
|
|
|
|
|
|
|
|
|
|
def test_masked_mean(): |
|
|
|