# Combine this agent's observations with its teammates' (when present) so
# the centralized critic sees the whole group.
all_obs = [current_obs] + team_obs if team_obs is not None else [current_obs]
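# Illustrative only (not from the original source): with two teammates,
# all_obs is [current_obs, teammate_0_obs, teammate_1_obs]; with none it is
# just [current_obs].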

# Per-sequence memories returned by the recurrent evaluation; these stay
# None on the non-recurrent path.
all_next_value_mem: Optional[AgentBufferField] = None
all_next_baseline_mem: Optional[AgentBufferField] = None
with torch.no_grad():
    if self.policy.use_recurrent:
        # Recurrent case: evaluate the trajectory sequence by sequence so
        # that LSTM memories are carried across consecutive sequences.
        (
            value_estimates,
            baseline_estimates,
            all_next_value_mem,
            all_next_baseline_mem,
            next_value_mem,
            next_baseline_mem,
        ) = self._evaluate_by_sequence_team(
            current_obs,
            team_obs,
            team_actions,
            _init_value_mem,
            _init_baseline_mem,
        )
    else:
        # Non-recurrent case: the whole batch can be evaluated in one pass.
        value_estimates, next_value_mem = self.critic.critic_pass(
            all_obs, _init_value_mem, sequence_length=batch.num_experiences
        )
        baseline_estimates, next_baseline_mem = self.critic.baseline(
            [current_obs],
            team_obs,
            team_actions,
            _init_baseline_mem,
            sequence_length=batch.num_experiences,
        )
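# Note: value_estimates condition on every agent's observations, while
# baseline_estimates additionally condition on teammates' actions, which
# supports POCA's counterfactual-style baseline for per-agent credit
# assignment.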
# Store the memory for the next trajectory
self.value_memory_dict[agent_id] = next_value_mem
self.baseline_memory_dict[agent_id] = next_baseline_mem
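# Sketch (assumed surrounding behavior, not verbatim from this file): on the
# next trajectory for the same agent_id, these dictionaries would seed the
# initial memories, e.g.:
#   if agent_id in self.value_memory_dict:
#       _init_value_mem = self.value_memory_dict[agent_id]
#       _init_baseline_mem = self.baseline_memory_dict[agent_id]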
|
|