浏览代码

get value estimate test

/develop/action-slice
Andrew Cohen 4 年前
当前提交
43955c5b
共有 1 个文件被更改,包括 25 次插入和 12 次删除
  1. 37
      ml-agents/mlagents/trainers/tests/torch/test_coma.py

37
ml-agents/mlagents/trainers/tests/torch/test_coma.py


@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
@pytest.mark.parametrize("rnn", [False], ids=["no_rnn"])
def test_coma_get_value_estimates(dummy_config, rnn, visual, discrete):
optimizer = create_test_coma_optimizer(
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual

max_step_complete=True,
num_other_agents_in_group=NUM_AGENTS,
)
run_out, final_value_out, all_memories = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(), trajectory.next_obs, done=False
value_estimates, baseline_estimates, next_value_estimates = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(),
trajectory.next_obs,
trajectory.next_group_obs,
done=False,
for key, val in run_out.items():
for key, val in value_estimates.items():
assert type(key) is str
assert len(val) == 15
for key, val in baseline_estimates.items():
if all_memories is not None:
assert len(all_memories) == 15
# if all_memories is not None:
# assert len(all_memories) == 15
run_out, final_value_out, _ = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(), trajectory.next_obs, done=True
value_estimates, baseline_estimates, next_value_estimates = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(),
trajectory.next_obs,
trajectory.next_group_obs,
done=True,
for key, val in final_value_out.items():
for key, val in next_value_estimates.items():
run_out, final_value_out, _ = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(), trajectory.next_obs, done=False
value_estimates, baseline_estimates, next_value_estimates = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(),
trajectory.next_obs,
trajectory.next_group_obs,
done=False,
for key, val in final_value_out.items():
for key, val in next_value_estimates.items():
assert type(key) is str
assert val != 0.0

正在加载...
取消
保存