|
|
|
|
|
|
|
|
|
|
trainer_settings = attr.evolve(dummy_config) |
|
|
|
trainer_settings.reward_signals = { |
|
|
|
RewardSignalType.EXTRINSIC: ExtrinsicSettings( |
|
|
|
strength=1.0, gamma=0.99, add_groupmate_rewards=True |
|
|
|
) |
|
|
|
RewardSignalType.EXTRINSIC: ExtrinsicSettings(strength=1.0, gamma=0.99) |
|
|
|
} |
|
|
|
|
|
|
|
trainer_settings.network_settings.memory = ( |
|
|
|
|
|
|
BufferKey.ENVIRONMENT_REWARDS, |
|
|
|
[ |
|
|
|
BufferKey.ADVANTAGES, |
|
|
|
RewardSignalUtil.returns_key("group"), |
|
|
|
RewardSignalUtil.value_estimates_key("group"), |
|
|
|
RewardSignalUtil.baseline_estimates_key("group"), |
|
|
|
RewardSignalUtil.returns_key("extrinsic"), |
|
|
|
RewardSignalUtil.value_estimates_key("extrinsic"), |
|
|
|
RewardSignalUtil.baseline_estimates_key("extrinsic"), |
|
|
|
copy_buffer_fields(update_buffer, BufferKey.MEMORY, [BufferKey.CRITIC_MEMORY]) |
|
|
|
copy_buffer_fields( |
|
|
|
update_buffer, |
|
|
|
BufferKey.MEMORY, |
|
|
|
[BufferKey.CRITIC_MEMORY, BufferKey.BASELINE_MEMORY], |
|
|
|
) |
|
|
|
|
|
|
|
return_stats = optimizer.update( |
|
|
|
update_buffer, |
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) |
|
|
|
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"]) |
|
|
|
@pytest.mark.parametrize("rnn", [False], ids=["no_rnn"]) |
|
|
|
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"]) |
|
|
|
def test_coma_get_value_estimates(dummy_config, rnn, visual, discrete): |
|
|
|
optimizer = create_test_coma_optimizer( |
|
|
|
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual |
|
|
|
|
|
|
( |
|
|
|
value_estimates, |
|
|
|
baseline_estimates, |
|
|
|
next_value_estimates, |
|
|
|
value_next, |
|
|
|
value_memories, |
|
|
|
baseline_memories, |
|
|
|
) = optimizer.get_trajectory_and_baseline_value_estimates( |
|
|
|
trajectory.to_agentbuffer(), |
|
|
|
trajectory.next_obs, |
|
|
|
|
|
|
assert type(key) is str |
|
|
|
assert len(val) == 15 |
|
|
|
|
|
|
|
# if all_memories is not None: |
|
|
|
# assert len(all_memories) == 15 |
|
|
|
if value_memories is not None: |
|
|
|
assert len(value_memories) == 15 |
|
|
|
assert len(baseline_memories) == 15 |
|
|
|
next_value_estimates, |
|
|
|
value_next, |
|
|
|
value_memories, |
|
|
|
baseline_memories, |
|
|
|
) = optimizer.get_trajectory_and_baseline_value_estimates( |
|
|
|
trajectory.to_agentbuffer(), |
|
|
|
trajectory.next_obs, |
|
|
|
|
|
|
for key, val in next_value_estimates.items(): |
|
|
|
for key, val in value_next.items(): |
|
|
|
optimizer.reward_signals["group"].use_terminal_states = False |
|
|
|
optimizer.reward_signals["extrinsic"].use_terminal_states = False |
|
|
|
next_value_estimates, |
|
|
|
value_next, |
|
|
|
value_memories, |
|
|
|
baseline_memories, |
|
|
|
) = optimizer.get_trajectory_and_baseline_value_estimates( |
|
|
|
trajectory.to_agentbuffer(), |
|
|
|
trajectory.next_obs, |
|
|
|
|
|
|
for key, val in next_value_estimates.items(): |
|
|
|
for key, val in value_next.items(): |
|
|
|
assert type(key) is str |
|
|
|
assert val != 0.0 |
|
|
|
|
|
|
|
|
|
|
BufferKey.ADVANTAGES, |
|
|
|
RewardSignalUtil.returns_key("extrinsic"), |
|
|
|
RewardSignalUtil.value_estimates_key("extrinsic"), |
|
|
|
RewardSignalUtil.baseline_estimates_key("extrinsic"), |
|
|
|
RewardSignalUtil.baseline_estimates_key("curiosity"), |
|
|
|
copy_buffer_fields(update_buffer, BufferKey.MEMORY, [BufferKey.CRITIC_MEMORY]) |
|
|
|
copy_buffer_fields( |
|
|
|
update_buffer, |
|
|
|
BufferKey.MEMORY, |
|
|
|
[BufferKey.CRITIC_MEMORY, BufferKey.BASELINE_MEMORY], |
|
|
|
) |
|
|
|
|
|
|
|
optimizer.update( |
|
|
|
update_buffer, |
|
|
|
|
|
|
BufferKey.ADVANTAGES, |
|
|
|
RewardSignalUtil.returns_key("extrinsic"), |
|
|
|
RewardSignalUtil.value_estimates_key("extrinsic"), |
|
|
|
RewardSignalUtil.baseline_estimates_key("extrinsic"), |
|
|
|
RewardSignalUtil.baseline_estimates_key("gail"), |
|
|
|
], |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
BufferKey.ADVANTAGES, |
|
|
|
RewardSignalUtil.returns_key("extrinsic"), |
|
|
|
RewardSignalUtil.value_estimates_key("extrinsic"), |
|
|
|
RewardSignalUtil.baseline_estimates_key("extrinsic"), |
|
|
|
RewardSignalUtil.baseline_estimates_key("gail"), |
|
|
|
], |
|
|
|
) |
|
|
|
optimizer.update( |
|
|
|