|
|
|
|
|
|
from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs |
|
|
|
from mlagents.trainers.torch.action_log_probs import LogProbsTuple |
|
|
|
from mlagents.trainers.stats import StatsReporter |
|
|
|
from mlagents.trainers.behavior_id_utils import get_global_agent_id |
|
|
|
from mlagents.trainers.behavior_id_utils import ( |
|
|
|
get_global_agent_id, |
|
|
|
get_global_manager_id, |
|
|
|
) |
|
|
|
|
|
|
|
T = TypeVar("T") |
|
|
|
|
|
|
|
|
|
|
for terminal_step in terminal_steps.values(): |
|
|
|
local_id = terminal_step.agent_id |
|
|
|
global_id = get_global_agent_id(worker_id, local_id) |
|
|
|
self._gather_teammate_obs(terminal_step, global_id) |
|
|
|
self._gather_teammate_obs(terminal_step, global_id, worker_id) |
|
|
|
terminal_step, global_id, terminal_steps.agent_id_to_index[local_id] |
|
|
|
terminal_step, |
|
|
|
global_id, |
|
|
|
worker_id, |
|
|
|
terminal_steps.agent_id_to_index[local_id], |
|
|
|
) |
|
|
|
# Clear the last seen group obs when agents die. |
|
|
|
self._clear_teammate_obs(global_id) |
|
|
|
|
|
|
for ongoing_step in decision_steps.values(): |
|
|
|
local_id = ongoing_step.agent_id |
|
|
|
global_id = get_global_agent_id(worker_id, local_id) |
|
|
|
self._gather_teammate_obs(ongoing_step, global_id) |
|
|
|
self._gather_teammate_obs(ongoing_step, global_id, worker_id) |
|
|
|
ongoing_step, global_id, decision_steps.agent_id_to_index[local_id] |
|
|
|
ongoing_step, |
|
|
|
global_id, |
|
|
|
worker_id, |
|
|
|
decision_steps.agent_id_to_index[local_id], |
|
|
|
) |
|
|
|
|
|
|
|
for _gid in action_global_agent_ids: |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
def _gather_teammate_obs( |
|
|
|
self, step: Union[TerminalStep, DecisionStep], global_id: str |
|
|
|
self, step: Union[TerminalStep, DecisionStep], global_id: str, worker_id: int |
|
|
|
self.last_group_obs[step.team_manager_id][ |
|
|
|
global_manager_id = get_global_manager_id( |
|
|
|
worker_id, step.team_manager_id |
|
|
|
) |
|
|
|
self.last_group_obs[global_manager_id][ |
|
|
|
self.current_group_obs[step.team_manager_id][global_id] = step.obs |
|
|
|
self.current_group_obs[global_manager_id][global_id] = step.obs |
|
|
|
|
|
|
|
def _clear_teammate_obs(self, global_id: str) -> None: |
|
|
|
for _manager_id in list(self.current_group_obs.keys()): |
|
|
|
|
|
|
self._safe_delete(self.last_group_obs, _manager_id) |
|
|
|
|
|
|
|
def _process_step( |
|
|
|
self, step: Union[TerminalStep, DecisionStep], global_id: str, index: int |
|
|
|
self, |
|
|
|
step: Union[TerminalStep, DecisionStep], |
|
|
|
global_id: str, |
|
|
|
worker_id: int, |
|
|
|
index: int, |
|
|
|
) -> None: |
|
|
|
terminated = isinstance(step, TerminalStep) |
|
|
|
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None)) |
|
|
|
|
|
|
|
|
|
|
# Assemble teammate_obs. If none saved, then it will be an empty list. |
|
|
|
collab_obs = [] |
|
|
|
for _id, _obs in self.last_group_obs[step.team_manager_id].items(): |
|
|
|
global_manager_id = get_global_manager_id(worker_id, step.team_manager_id) |
|
|
|
for _id, _obs in self.last_group_obs[global_manager_id].items(): |
|
|
|
if _id != global_id: |
|
|
|
collab_obs.append(_obs) |
|
|
|
|
|
|
|
|
|
|
): |
|
|
|
next_obs = step.obs |
|
|
|
next_collab_obs = [] |
|
|
|
for _id, _exp in self.current_group_obs[step.team_manager_id].items(): |
|
|
|
global_manager_id = get_global_manager_id( |
|
|
|
worker_id, step.team_manager_id |
|
|
|
) |
|
|
|
for _id, _exp in self.current_group_obs[global_manager_id].items(): |
|
|
|
if _id != global_id: |
|
|
|
next_collab_obs.append(_exp) |
|
|
|
|
|
|
|