|
|
|
|
|
|
for i in range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length) |
|
|
|
] |
|
|
|
# LSTM shouldn't have sequence length <1, but stop it from going out of the index if true. |
|
|
|
value_memories_list = [ |
|
|
|
ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i]) |
|
|
|
for i in range( |
|
|
|
0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length |
|
|
|
) |
|
|
|
] |
|
|
|
next_memories_list = [ |
|
|
|
next_value_memories_list = [ |
|
|
|
batch[BufferKey.MEMORY][i] |
|
|
|
batch[BufferKey.CRITIC_MEMORY][i] |
|
|
|
offset, len(batch[BufferKey.MEMORY]), self.policy.sequence_length |
|
|
|
offset, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length |
|
|
|
next_memories = torch.stack(next_memories_list).unsqueeze(0) |
|
|
|
value_memories = torch.stack(value_memories_list).unsqueeze(0) |
|
|
|
next_value_memories = torch.stack(next_value_memories_list).unsqueeze(0) |
|
|
|
next_memories = None |
|
|
|
value_memories = None |
|
|
|
next_value_memories = None |
|
|
|
torch.zeros_like(next_memories) if next_memories is not None else None |
|
|
|
) |
|
|
|
v_memories = ( |
|
|
|
torch.zeros_like(next_memories) if next_memories is not None else None |
|
|
|
torch.zeros_like(next_value_memories) |
|
|
|
if next_value_memories is not None |
|
|
|
else None |
|
|
|
) |
|
|
|
|
|
|
|
# Copy normalizers from policy |
|
|
|
|
|
|
sequence_length=self.policy.sequence_length, |
|
|
|
) |
|
|
|
value_estimates, _ = self.value_network.critic_pass( |
|
|
|
current_obs, v_memories, sequence_length=self.policy.sequence_length |
|
|
|
current_obs, value_memories, sequence_length=self.policy.sequence_length |
|
|
|
) |
|
|
|
|
|
|
|
cont_sampled_actions = sampled_actions.continuous_tensor |
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
target_values, _ = self.target_network( |
|
|
|
next_obs, |
|
|
|
memories=next_memories, |
|
|
|
memories=next_value_memories, |
|
|
|
sequence_length=self.policy.sequence_length, |
|
|
|
) |
|
|
|
masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool) |
|
|
|