|
|
|
|
|
|
|
|
|
|
run_out = policy.evaluate(decision_step, list(decision_step.agent_id)) |
|
|
|
if discrete: |
|
|
|
run_out["action"]["discrete_action"].shape == ( |
|
|
|
NUM_AGENTS, |
|
|
|
len(DISCRETE_ACTION_SPACE), |
|
|
|
) |
|
|
|
run_out["action"].discrete.shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE)) |
|
|
|
assert run_out["action"]["continuous_action"].shape == ( |
|
|
|
NUM_AGENTS, |
|
|
|
VECTOR_ACTION_SPACE, |
|
|
|
) |
|
|
|
assert run_out["action"].continuous.shape == (NUM_AGENTS, VECTOR_ACTION_SPACE) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) |
|
|
|