|
|
|
|
|
|
self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list) |
|
|
|
self.last_step_result: Dict[str, Tuple[DecisionStep, int]] = {} |
|
|
|
# current_group_obs is used to collect the last seen obs of all the agents in the same group, |
|
|
|
# and assemble the next_collab_obs. |
|
|
|
# and assemble the collab_obs. |
|
|
|
self.current_group_obs: Dict[str, Dict[str, List[np.ndarray]]] = defaultdict( |
|
|
|
lambda: defaultdict(list) |
|
|
|
) |
|
|
|
|
|
|
lambda: defaultdict(list) |
|
|
|
) |
|
|
|
# current_group_rewards is used to collect the last seen rewards of all the agents in the same group. |
|
|
|
self.current_group_rewards: Dict[str, Dict[str, float]] = defaultdict( |
|
|
|
lambda: defaultdict(float) |
|
|
|
) |
|
|
|
# last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while |
|
|
|
# grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1). |
|
|
|
|
|
|
global_id |
|
|
|
] = stored_decision_step.obs |
|
|
|
self.current_group_obs[step.team_manager_id][global_id] = step.obs |
|
|
|
self.current_group_rewards[step.team_manager_id][ |
|
|
|
global_id |
|
|
|
] = step.reward |
|
|
|
for _manager_id, _team_group in self.current_group_obs.items(): |
|
|
|
self._safe_delete(_team_group, global_id) |
|
|
|
if not _team_group: # if dict is empty |
|
|
|
self._safe_delete(_team_group, _manager_id) |
|
|
|
for _manager_id, _team_group in self.last_group_obs.items(): |
|
|
|
self._safe_delete(_team_group, global_id) |
|
|
|
self._delete_in_nested_dict(self.current_group_obs, global_id) |
|
|
|
self._delete_in_nested_dict(self.last_group_obs, global_id) |
|
|
|
self._delete_in_nested_dict(self.current_group_rewards, global_id) |
|
|
|
|
|
|
|
def _delete_in_nested_dict(self, nested_dict, key):
    """
    Remove `key` from every inner dict of `nested_dict`, then drop any inner
    dict that is empty afterwards from `nested_dict` itself.

    :param nested_dict: Two-level mapping (outer id -> inner dict). It is
        mutated in place.
    :param key: Key to remove from each inner dict.
    """
    # Iterate over a snapshot of the outer keys: entries are removed from
    # nested_dict inside the loop, and mutating a dict while iterating its
    # live view raises RuntimeError.
    for _manager_id in list(nested_dict):
        _team_group = nested_dict[_manager_id]
        # Deletion goes through self._safe_delete (defined elsewhere);
        # presumably it is a no-op when the key is absent - confirm there.
        self._safe_delete(_team_group, key)
        if not _team_group:  # if dict is empty
            # Prune the empty group from the OUTER dict. Deleting the
            # manager id from the (empty) inner dict would do nothing and
            # leave stale empty groups behind.
            self._safe_delete(nested_dict, _manager_id)
|
|
|
|
|
|
|
|
|
|
for _id, _obs in self.last_group_obs[step.team_manager_id].items(): |
|
|
|
if _id != global_id: |
|
|
|
collab_obs.append(_obs) |
|
|
|
teammate_rewards = [] |
|
|
|
for _id, _rew in self.current_group_rewards[step.team_manager_id].items(): |
|
|
|
if _id != global_id: |
|
|
|
teammate_rewards.append(_rew) |
|
|
|
team_rewards=teammate_rewards, |
|
|
|
done=done, |
|
|
|
action=action_tuple, |
|
|
|
action_probs=log_probs_tuple, |
|
|
|