        memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
            for i in range(
                0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length
            )
        ]
        if len(memories) > 0:
            memories = torch.stack(memories).unsqueeze(0)

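        # Initial critic and baseline memories, sampled at the start of every
        # sequence in the batch (mirroring the policy memories gathered above).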
        value_memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
            for i in range(
                0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
            )
        ]

        baseline_memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.BASELINE_MEMORY][i])
            for i in range(
                0, len(batch[BufferKey.BASELINE_MEMORY]), self.policy.sequence_length
            )
        ]

        if len(value_memories) > 0:
            value_memories = torch.stack(value_memories).unsqueeze(0)
            baseline_memories = torch.stack(baseline_memories).unsqueeze(0)

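        # Re-evaluate the stored actions under the current policy, then run the
        # centralized value function and the per-agent baseline on the same batch.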
        log_probs, entropy = self.policy.evaluate_actions(
            current_obs,
            masks=act_masks,
            actions=actions,
            memories=memories,
            seq_len=self.policy.sequence_length,
        )
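        # The value estimate conditions on every agent's observations; the baseline
        # additionally conditions on the other agents' actions.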
        all_obs = [current_obs] + group_obs
        values, _ = self.critic.critic_pass(
            all_obs,
            memories=value_memories,
            sequence_length=self.policy.sequence_length,
        )
        baselines, _ = self.critic.baseline(
            current_obs,
            group_obs,
            group_actions,
            memories=baseline_memories,
            sequence_length=self.policy.sequence_length,
        )
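        # Log-probabilities the actions had when they were collected, used to form
        # the policy update ratio.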
        old_log_probs = ActionLogProbs.from_buffer(batch).flatten()

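        # For each teammate, slice out the first `first_seq_len` steps of every
        # observation tensor; this leading (possibly partial) sequence is evaluated
        # with the initial memories.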
        for team_obs, team_action in zip(obs, actions):
            seq_obs = []
            for _obs in team_obs:
                first_seq_obs = _obs[0:first_seq_len]
                seq_obs.append(first_seq_obs)
            team_seq_obs.append(seq_obs)

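        # Load the critic / baseline memories carried over from this agent's
        # previous trajectory, or initialize them (zeros when recurrent, None otherwise).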
        if agent_id in self.value_memory_dict:
            _init_value_mem = self.value_memory_dict[agent_id]
            _init_baseline_mem = self.baseline_memory_dict[agent_id]
        else:
            _init_value_mem = (
                torch.zeros((1, 1, self.critic.memory_size))
                if self.policy.use_recurrent
                else None
            )
            _init_baseline_mem = (
                torch.zeros((1, 1, self.critic.memory_size))
                if self.policy.use_recurrent
                else None
            )

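        # Per-step memories for the whole trajectory, filled in only on the
        # recurrent path.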
        all_next_value_mem: Optional[AgentBufferField] = None
        all_next_baseline_mem: Optional[AgentBufferField] = None
        if self.policy.use_recurrent:
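            # Recurrent case: evaluate the trajectory sequence by sequence so the
            # critic and baseline memories are carried across sequence boundaries.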
            (
                value_estimates,
                baseline_estimates,
                all_next_value_mem,
                all_next_baseline_mem,
                next_value_mem,
                next_baseline_mem,
            ) = self._evaluate_by_sequence_team(
                all_obs, _init_value_mem, sequence_length=batch.num_experiences
            )
        else:
            baseline_estimates, baseline_mem = self.critic.baseline(
                all_obs,
                _init_baseline_mem,
                sequence_length=batch.num_experiences,
            )
        # Store the memory for the next trajectory