浏览代码

torch coma tests: lstm, cur, gail

/develop/coma2/fixgroup
Andrew Cohen 4 年前
当前提交
853b44d5
共有 1 个文件被更改,包括 36 次插入17 次删除
  1. 53
      ml-agents/mlagents/trainers/tests/torch/test_coma.py

53
ml-agents/mlagents/trainers/tests/torch/test_coma.py


trainer_settings = attr.evolve(dummy_config)
trainer_settings.reward_signals = {
RewardSignalType.EXTRINSIC: ExtrinsicSettings(
strength=1.0, gamma=0.99, add_groupmate_rewards=True
)
RewardSignalType.EXTRINSIC: ExtrinsicSettings(strength=1.0, gamma=0.99)
}
trainer_settings.network_settings.memory = (

BufferKey.ENVIRONMENT_REWARDS,
[
BufferKey.ADVANTAGES,
RewardSignalUtil.returns_key("group"),
RewardSignalUtil.value_estimates_key("group"),
RewardSignalUtil.baseline_estimates_key("group"),
RewardSignalUtil.returns_key("extrinsic"),
RewardSignalUtil.value_estimates_key("extrinsic"),
RewardSignalUtil.baseline_estimates_key("extrinsic"),
copy_buffer_fields(update_buffer, BufferKey.MEMORY, [BufferKey.CRITIC_MEMORY])
copy_buffer_fields(
update_buffer,
BufferKey.MEMORY,
[BufferKey.CRITIC_MEMORY, BufferKey.BASELINE_MEMORY],
)
return_stats = optimizer.update(
update_buffer,

@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [False], ids=["no_rnn"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_coma_get_value_estimates(dummy_config, rnn, visual, discrete):
optimizer = create_test_coma_optimizer(
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual

(
value_estimates,
baseline_estimates,
next_value_estimates,
value_next,
value_memories,
baseline_memories,
) = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(),
trajectory.next_obs,

assert type(key) is str
assert len(val) == 15
# if all_memories is not None:
# assert len(all_memories) == 15
if value_memories is not None:
assert len(value_memories) == 15
assert len(baseline_memories) == 15
next_value_estimates,
value_next,
value_memories,
baseline_memories,
) = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(),
trajectory.next_obs,

for key, val in next_value_estimates.items():
for key, val in value_next.items():
optimizer.reward_signals["group"].use_terminal_states = False
optimizer.reward_signals["extrinsic"].use_terminal_states = False
next_value_estimates,
value_next,
value_memories,
baseline_memories,
) = optimizer.get_trajectory_and_baseline_value_estimates(
trajectory.to_agentbuffer(),
trajectory.next_obs,

for key, val in next_value_estimates.items():
for key, val in value_next.items():
assert type(key) is str
assert val != 0.0

BufferKey.ADVANTAGES,
RewardSignalUtil.returns_key("extrinsic"),
RewardSignalUtil.value_estimates_key("extrinsic"),
RewardSignalUtil.baseline_estimates_key("extrinsic"),
RewardSignalUtil.baseline_estimates_key("curiosity"),
copy_buffer_fields(update_buffer, BufferKey.MEMORY, [BufferKey.CRITIC_MEMORY])
copy_buffer_fields(
update_buffer,
BufferKey.MEMORY,
[BufferKey.CRITIC_MEMORY, BufferKey.BASELINE_MEMORY],
)
optimizer.update(
update_buffer,

BufferKey.ADVANTAGES,
RewardSignalUtil.returns_key("extrinsic"),
RewardSignalUtil.value_estimates_key("extrinsic"),
RewardSignalUtil.baseline_estimates_key("extrinsic"),
RewardSignalUtil.baseline_estimates_key("gail"),
],
)

BufferKey.ADVANTAGES,
RewardSignalUtil.returns_key("extrinsic"),
RewardSignalUtil.value_estimates_key("extrinsic"),
RewardSignalUtil.baseline_estimates_key("extrinsic"),
RewardSignalUtil.baseline_estimates_key("gail"),
],
)
optimizer.update(

正在加载...
取消
保存