# If we're using LSTM, we want to get all the intermediate memories.
all_next_memories: Optional[AgentBufferField] = None
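# It is only filled in on the recurrent path below; for non-recurrent
# policies it stays None.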

# To prevent memory leak and improve performance, evaluate with no_grad.
with torch.no_grad():
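    # No autograd graph is built here, so evaluating the whole trajectory
    # does not retain activations, and the outputs come back detached.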
    if self.policy.use_recurrent:
        # Evaluate sequence-by-sequence so the LSTM memory is threaded
        # through the trajectory, collecting the memory at each step.
        (
            value_estimates,
            all_next_memories,
            next_memory,
        ) = self._evaluate_by_sequence(current_obs, memory)
    else:
        # Without recurrence, the whole batch is evaluated in a single pass.
        value_estimates, next_memory = self.critic.critic_pass(
            current_obs, memory, sequence_length=batch.num_experiences
        )

# Store the memory for the next trajectory. This should NOT have a gradient;
# next_memory comes from the no_grad block above, so it carries none.
self.critic_memory_dict[agent_id] = next_memory
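
# Estimate the value of the state one step past the trajectory's end; this
# is used to bootstrap the discounted return for the final step.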
next_value_estimate, _ = self.critic.critic_pass(