|
|
|
|
|
|
0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length |
|
|
|
) |
|
|
|
] |
|
|
|
offset = 1 if self.policy.sequence_length > 1 else 0 |
|
|
|
next_value_memories_list = [ |
|
|
|
ModelUtils.list_to_tensor( |
|
|
|
batch[BufferKey.CRITIC_MEMORY][i] |
|
|
|
) # only pass value part of memory to target network |
|
|
|
for i in range( |
|
|
|
offset, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length |
|
|
|
) |
|
|
|
] |
|
|
|
next_value_memories = torch.stack(next_value_memories_list).unsqueeze(0) |
|
|
|
next_value_memories = None |
|
|
|
torch.zeros_like(next_value_memories) |
|
|
|
if next_value_memories is not None |
|
|
|
else None |
|
|
|
torch.zeros_like(value_memories) if value_memories is not None else None |
|
|
|
) |
|
|
|
|
|
|
|
# Copy normalizers from policy |
|
|
|
|
|
|
q1_stream, q2_stream = q1_out, q2_out |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
# Since we didn't record the next value memories, evaluate one step in the critic to |
|
|
|
# get them. |
|
|
|
if value_memories is not None: |
|
|
|
# Get the first observation in each sequence |
|
|
|
just_first_obs = [ |
|
|
|
_obs[:: self.policy.sequence_length] for _obs in current_obs |
|
|
|
] |
|
|
|
_, next_value_memories = self._critic.critic_pass( |
|
|
|
just_first_obs, value_memories, sequence_length=1 |
|
|
|
) |
|
|
|
else: |
|
|
|
next_value_memories = None |
|
|
|
target_values, _ = self.target_network( |
|
|
|
next_obs, |
|
|
|
memories=next_value_memories, |
|
|
|