|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) |
|
|
|
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"]) |
|
|
|
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"]) |
|
|
|
@pytest.mark.parametrize("rnn", [False], ids=["no_rnn"]) |
|
|
|
def test_coma_get_value_estimates(dummy_config, rnn, visual, discrete): |
|
|
|
optimizer = create_test_coma_optimizer( |
|
|
|
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual |
|
|
|
|
|
|
max_step_complete=True, |
|
|
|
num_other_agents_in_group=NUM_AGENTS, |
|
|
|
) |
|
|
|
run_out, final_value_out, all_memories = optimizer.get_trajectory_and_baseline_value_estimates( |
|
|
|
trajectory.to_agentbuffer(), trajectory.next_obs, done=False |
|
|
|
value_estimates, baseline_estimates, next_value_estimates = optimizer.get_trajectory_and_baseline_value_estimates( |
|
|
|
trajectory.to_agentbuffer(), |
|
|
|
trajectory.next_obs, |
|
|
|
trajectory.next_group_obs, |
|
|
|
done=False, |
|
|
|
for key, val in run_out.items(): |
|
|
|
for key, val in value_estimates.items(): |
|
|
|
assert type(key) is str |
|
|
|
assert len(val) == 15 |
|
|
|
for key, val in baseline_estimates.items(): |
|
|
|
if all_memories is not None: |
|
|
|
assert len(all_memories) == 15 |
|
|
|
|
|
|
|
# if all_memories is not None: |
|
|
|
# assert len(all_memories) == 15 |
|
|
|
run_out, final_value_out, _ = optimizer.get_trajectory_and_baseline_value_estimates( |
|
|
|
trajectory.to_agentbuffer(), trajectory.next_obs, done=True |
|
|
|
value_estimates, baseline_estimates, next_value_estimates = optimizer.get_trajectory_and_baseline_value_estimates( |
|
|
|
trajectory.to_agentbuffer(), |
|
|
|
trajectory.next_obs, |
|
|
|
trajectory.next_group_obs, |
|
|
|
done=True, |
|
|
|
for key, val in final_value_out.items(): |
|
|
|
for key, val in next_value_estimates.items(): |
|
|
|
run_out, final_value_out, _ = optimizer.get_trajectory_and_baseline_value_estimates( |
|
|
|
trajectory.to_agentbuffer(), trajectory.next_obs, done=False |
|
|
|
value_estimates, baseline_estimates, next_value_estimates = optimizer.get_trajectory_and_baseline_value_estimates( |
|
|
|
trajectory.to_agentbuffer(), |
|
|
|
trajectory.next_obs, |
|
|
|
trajectory.next_group_obs, |
|
|
|
done=False, |
|
|
|
for key, val in final_value_out.items(): |
|
|
|
for key, val in next_value_estimates.items(): |
|
|
|
assert type(key) is str |
|
|
|
assert val != 0.0 |
|
|
|
|
|
|
|