        :param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
        """
        self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
        self.last_experience: Dict[str, AgentExperience] = {}
        # current_group_obs collects the obs of all the agents in the same group
        # for the current step.
        self.current_group_obs: Dict[str, Dict[str, List[np.ndarray]]] = defaultdict(
            lambda: defaultdict(list)
        )
        # last_group_obs is used to collect the last seen obs of all the agents in the same group,
        # and assemble the collab_obs.
        self.last_group_obs: Dict[str, Dict[str, List[np.ndarray]]] = defaultdict(
            lambda: defaultdict(list)
        )
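        # Both of the group-obs maps above share the same nested layout:
        #   {team_manager_id: {global_agent_id: [obs_array_0, obs_array_1, ...]}}
        # i.e. one list of observation arrays per agent, grouped by its team manager.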
        # last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
        # grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
        self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
        for global_id in action_global_agent_ids:
            if global_id in self.last_step_result:  # Don't store if agent just reset
                self.last_take_action_outputs[global_id] = take_action_outputs

        # Iterate over all the terminal steps, first gather all the teammate obs
        # and then create the AgentExperiences/Trajectories
        for terminal_step in terminal_steps.values():
            local_id = terminal_step.agent_id
            global_id = get_global_agent_id(worker_id, local_id)
            self._gather_teammate_obs(terminal_step, global_id)
        for terminal_step in terminal_steps.values():
            local_id = terminal_step.agent_id
            global_id = get_global_agent_id(worker_id, local_id)
            self._process_step(
                terminal_step, global_id, terminal_steps.agent_id_to_index[local_id]
            )
            self._assemble_trajectory(terminal_step, global_id)
            # Clear the last seen group obs when agents die.
            self._clear_teammate_obs(global_id)
            self._safe_delete(self.last_experience, global_id)
        self.current_group_obs.clear()
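        # Terminal agents do not show up in later batches, so their cached
        # last_experience and group obs entries are dropped right away; ongoing
        # agents keep theirs.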

        # Iterate over all the decision steps, first gather all the teammate obs
        # and then create the trajectories
        for ongoing_step in decision_steps.values():
            local_id = ongoing_step.agent_id
            global_id = get_global_agent_id(worker_id, local_id)
            self._gather_teammate_obs(ongoing_step, global_id)
        for ongoing_step in decision_steps.values():
            local_id = ongoing_step.agent_id
            global_id = get_global_agent_id(worker_id, local_id)
            self._process_step(
                ongoing_step, global_id, decision_steps.agent_id_to_index[local_id]
            )
            self._assemble_trajectory(ongoing_step, global_id)
        self.current_group_obs.clear()
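        # Note: current_group_obs only lives for one batch of steps (it is cleared
        # above once all experiences have been created), while last_group_obs
        # persists across batches until _clear_teammate_obs drops a terminated
        # agent's entries.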

        for _gid in action_global_agent_ids:
            # If the ID doesn't have a last step result, the agent just reset,
            # don't store the action.
            if _gid in self.last_step_result:
                self.policy.save_previous_action(
                    [_gid], take_action_outputs["action"]
                )

    def _gather_teammate_obs(
        self, step: Union[TerminalStep, DecisionStep], global_id: str
    ) -> None:
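        """
        Store this agent's obs for its group: the obs from the previously stored
        decision step go into last_group_obs and the obs of the incoming step go
        into current_group_obs, keyed by team_manager_id and the agent's global id.
        """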
        stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
        if stored_decision_step is not None:
            if step.team_manager_id is not None:
                self.last_group_obs[step.team_manager_id][
                    global_id
                ] = stored_decision_step.obs
                self.current_group_obs[step.team_manager_id][global_id] = step.obs

    def _clear_teammate_obs(self, global_id: str) -> None:
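        """
        Remove this agent's entries from current_group_obs and last_group_obs,
        dropping a team manager's entry entirely once its group dict is empty.
        """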
        for _manager_id in list(self.current_group_obs.keys()):
            self._safe_delete(self.current_group_obs[_manager_id], global_id)
            if not self.current_group_obs[_manager_id]:  # if dict is empty
                self._safe_delete(self.current_group_obs, _manager_id)
        for _manager_id in list(self.last_group_obs.keys()):
            self._safe_delete(self.last_group_obs[_manager_id], global_id)
            if not self.last_group_obs[_manager_id]:  # if dict is empty
                self._safe_delete(self.last_group_obs, _manager_id)

    def _process_step(
        self, step: Union[TerminalStep, DecisionStep], global_id: str, index: int
    ) -> None:
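        """
        Turn the stored decision step and the newly received step into an
        AgentExperience, attaching the teammate obs gathered in last_group_obs as
        collab_obs, and cache it in last_experience for _assemble_trajectory.
        """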
        stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
        action_mask = stored_decision_step.action_mask
        prev_action = self.policy.retrieve_previous_action([global_id])[0, :]

        # Assemble teammate_obs. If none saved, then it will be an empty list.
        collab_obs = []
        for _id, _obs in self.last_group_obs[step.team_manager_id].items():
            if _id != global_id:
                collab_obs.append(_obs)

        experience = AgentExperience(
            obs=stored_decision_step.obs,
            collab_obs=collab_obs,
            reward=step.reward,
            done=done,
            action=action_tuple,
            action_mask=action_mask,
            prev_action=prev_action,
            interrupted=interrupted,
            memory=memory,
        )
        if step.team_manager_id is not None:
            self.current_group_obs[step.team_manager_id][global_id] += step.obs
        self.last_experience[global_id] = experience

    def _assemble_trajectory(
        self, step: Union[TerminalStep, DecisionStep], global_id: str
    ) -> None:
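        """
        Append the remaining teammate obs cached in last_experience to this agent's
        AgentExperience, then push the experience into the experience buffer and
        update the running episode reward.
        """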
        if global_id in self.last_experience:
            experience = self.last_experience[global_id]
            terminated = isinstance(step, TerminalStep)

            # Add remaining shared obs to AgentExperience
            for _id, _exp in self.last_experience.items():
                if _id == global_id:
                    continue
                self.last_experience[global_id].collab_obs.append(_exp.obs)
            # Add the value outputs if needed
            self.experience_buffers[global_id].append(experience)
            self.episode_rewards[global_id] += step.reward