|
|
|
|
|
|
if len(memories) > 0: |
|
|
|
memories = torch.stack(memories).unsqueeze(0) |
|
|
|
|
|
|
|
log_probs, entropy, values = policy.evaluate_actions( |
|
|
|
log_probs, entropy = policy.evaluate_actions( |
|
|
|
tensor_obs, |
|
|
|
masks=act_masks, |
|
|
|
actions=agent_action, |
|
|
|
|
|
|
|
|
|
|
assert log_probs.flatten().shape == (64, _size) |
|
|
|
assert entropy.shape == (64,) |
|
|
|
for val in values.values(): |
|
|
|
assert val.shape == (64,) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) |
|
|
|