|
|
|
|
|
|
for terminal_step in terminal_steps.values(): |
|
|
|
local_id = terminal_step.agent_id |
|
|
|
global_id = get_global_agent_id(worker_id, local_id) |
|
|
|
self._gather_teammate_obs(terminal_step, global_id) |
|
|
|
self._gather_group_obs(terminal_step, global_id) |
|
|
|
for terminal_step in terminal_steps.values(): |
|
|
|
local_id = terminal_step.agent_id |
|
|
|
global_id = get_global_agent_id(worker_id, local_id) |
|
|
|
|
|
|
# Clear the last seen group obs when agents die. |
|
|
|
self._clear_teammate_obs(global_id) |
|
|
|
self._clear_group_obs(global_id) |
|
|
|
|
|
|
|
# Clean the last experience dictionary for terminal steps |
|
|
|
for terminal_step in terminal_steps.values(): |
|
|
|
|
|
|
for ongoing_step in decision_steps.values(): |
|
|
|
local_id = ongoing_step.agent_id |
|
|
|
global_id = get_global_agent_id(worker_id, local_id) |
|
|
|
self._gather_teammate_obs(ongoing_step, global_id) |
|
|
|
self._gather_group_obs(ongoing_step, global_id) |
|
|
|
for ongoing_step in decision_steps.values(): |
|
|
|
local_id = ongoing_step.agent_id |
|
|
|
global_id = get_global_agent_id(worker_id, local_id) |
|
|
|
|
|
|
[_gid], take_action_outputs["action"] |
|
|
|
) |
|
|
|
|
|
|
|
def _gather_teammate_obs( |
|
|
|
def _gather_group_obs( |
|
|
|
self, step: Union[TerminalStep, DecisionStep], global_id: str |
|
|
|
) -> None: |
|
|
|
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None)) |
|
|
|
|
|
|
self.group_status[step.group_id][global_id] = group_status |
|
|
|
self.current_group_obs[step.group_id][global_id] = step.obs |
|
|
|
|
|
|
|
def _clear_teammate_obs(self, global_id: str) -> None: |
|
|
|
def _clear_group_obs(self, global_id: str) -> None: |
|
|
|
self._delete_in_nested_dict(self.current_group_obs, global_id) |
|
|
|
self._delete_in_nested_dict(self.group_status, global_id) |
|
|
|
|
|
|
|