|
|
|
|
|
|
from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs |
|
|
|
from mlagents.trainers.torch.action_log_probs import LogProbsTuple |
|
|
|
from mlagents.trainers.stats import StatsReporter |
|
|
|
from mlagents.trainers.behavior_id_utils import ( |
|
|
|
get_global_agent_id, |
|
|
|
get_global_manager_id, |
|
|
|
) |
|
|
|
from mlagents.trainers.behavior_id_utils import get_global_agent_id, get_global_group_id |
|
|
|
|
|
|
|
T = TypeVar("T") |
|
|
|
|
|
|
|
|
|
|
) -> None: |
|
|
|
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None)) |
|
|
|
if stored_decision_step is not None: |
|
|
|
if step.team_manager_id > 0: |
|
|
|
global_manager_id = get_global_manager_id( |
|
|
|
worker_id, step.team_manager_id |
|
|
|
) |
|
|
|
self.last_group_obs[global_manager_id][ |
|
|
|
if step.group_id > 0: |
|
|
|
global_group_id = get_global_group_id(worker_id, step.group_id) |
|
|
|
self.last_group_obs[global_group_id][ |
|
|
|
self.current_group_obs[global_manager_id][global_id] = step.obs |
|
|
|
self.current_group_obs[global_group_id][global_id] = step.obs |
|
|
|
for _manager_id in list(self.current_group_obs.keys()): |
|
|
|
_team_group = self.current_group_obs[_manager_id] |
|
|
|
for _group_id in list(self.current_group_obs.keys()): |
|
|
|
_team_group = self.current_group_obs[_group_id] |
|
|
|
self._safe_delete(self.current_group_obs, _manager_id) |
|
|
|
for _manager_id in list(self.last_group_obs.keys()): |
|
|
|
_team_group = self.last_group_obs[_manager_id] |
|
|
|
self._safe_delete(self.current_group_obs, _group_id) |
|
|
|
for _group_id in list(self.last_group_obs.keys()): |
|
|
|
_team_group = self.last_group_obs[_group_id] |
|
|
|
self._safe_delete(self.last_group_obs, _manager_id) |
|
|
|
self._safe_delete(self.last_group_obs, _group_id) |
|
|
|
|
|
|
|
def _process_step( |
|
|
|
self, |
|
|
|
|
|
|
|
|
|
|
# Assemble teammate_obs. If none saved, then it will be an empty list. |
|
|
|
collab_obs = [] |
|
|
|
global_manager_id = get_global_manager_id(worker_id, step.team_manager_id) |
|
|
|
for _id, _obs in self.last_group_obs[global_manager_id].items(): |
|
|
|
global_group_id = get_global_group_id(worker_id, step.group_id) |
|
|
|
for _id, _obs in self.last_group_obs[global_group_id].items(): |
|
|
|
if _id != global_id: |
|
|
|
collab_obs.append(_obs) |
|
|
|
|
|
|
|
|
|
|
): |
|
|
|
next_obs = step.obs |
|
|
|
next_collab_obs = [] |
|
|
|
global_manager_id = get_global_manager_id( |
|
|
|
worker_id, step.team_manager_id |
|
|
|
) |
|
|
|
for _id, _exp in self.current_group_obs[global_manager_id].items(): |
|
|
|
global_group_id = get_global_group_id(worker_id, step.group_id) |
|
|
|
for _id, _exp in self.current_group_obs[global_group_id].items(): |
|
|
|
if _id != global_id: |
|
|
|
next_collab_obs.append(_exp) |
|
|
|
|
|
|
|