|
|
|
|
|
|
self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list) |
|
|
|
self.last_experience: Dict[str, AgentExperience] = {} |
|
|
|
self.last_step_result: Dict[str, Tuple[DecisionStep, int]] = {} |
|
|
|
# current_obs is used to collect the last seen obs of all the agents, and assemble the next_collab_obs. |
|
|
|
self.current_obs: Dict[str, List[np.ndarray]] = {} |
|
|
|
# current_group_obs is used to collect the last seen obs of all the agents, and assemble the next_collab_obs. |
|
|
|
self.current_group_obs: Dict[str, Dict[str, List[np.ndarray]]] = defaultdict( |
|
|
|
lambda: defaultdict(list) |
|
|
|
) |
|
|
|
# last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while |
|
|
|
# grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1). |
|
|
|
self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {} |
|
|
|
|
|
|
self.last_take_action_outputs[global_id] = take_action_outputs |
|
|
|
|
|
|
|
# Iterate over all the terminal steps |
|
|
|
# print("processing terminal_step") |
|
|
|
for terminal_step in terminal_steps.values(): |
|
|
|
local_id = terminal_step.agent_id |
|
|
|
global_id = get_global_agent_id(worker_id, local_id) |
|
|
|
|
|
|
local_id = terminal_step.agent_id |
|
|
|
global_id = get_global_agent_id(worker_id, local_id) |
|
|
|
self._assemble_trajectory(terminal_step, global_id) |
|
|
|
self.current_obs.clear() |
|
|
|
self.current_group_obs.clear() |
|
|
|
# print("clear terminal_step") |
|
|
|
|
|
|
|
# Clean the last experience dictionary for terminal steps |
|
|
|
for terminal_step in terminal_steps.values(): |
|
|
|
|
|
|
|
|
|
|
# Iterate over all the decision steps |
|
|
|
# print("processing decision_steps") |
|
|
|
for ongoing_step in decision_steps.values(): |
|
|
|
local_id = ongoing_step.agent_id |
|
|
|
global_id = get_global_agent_id(worker_id, local_id) |
|
|
|
|
|
|
local_id = ongoing_step.agent_id |
|
|
|
global_id = get_global_agent_id(worker_id, local_id) |
|
|
|
self._assemble_trajectory(ongoing_step, global_id) |
|
|
|
self.current_obs.clear() |
|
|
|
self.current_group_obs.clear() |
|
|
|
# print("clear decision_steps") |
|
|
|
|
|
|
|
for _gid in action_global_agent_ids: |
|
|
|
# If the ID doesn't have a last step result, the agent just reset, |
|
|
|
|
|
|
interrupted=interrupted, |
|
|
|
memory=memory, |
|
|
|
) |
|
|
|
self.current_obs[global_id] = step.obs |
|
|
|
if step.team_manager_id is not None: |
|
|
|
self.current_group_obs[step.team_manager_id][global_id] += step.obs |
|
|
|
self.last_experience[global_id] = experience |
|
|
|
|
|
|
|
def _assemble_trajectory( |
|
|
|
|
|
|
): |
|
|
|
next_obs = step.obs |
|
|
|
next_collab_obs = [] |
|
|
|
for _id, _exp in self.current_obs.items(): |
|
|
|
for _id, _exp in self.current_group_obs[step.team_manager_id].items(): |
|
|
|
if _id == global_id: |
|
|
|
continue |
|
|
|
else: |
|
|
|