from collections import defaultdict
from typing import Dict, List, Union

import numpy as np

from mlagents_envs.base_env import DecisionStep, TerminalStep, ActionTuple
from mlagents_envs.side_channel.stats_side_channel import (
    StatsAggregationMethod,
    EnvironmentStats,
)
from mlagents.trainers.trajectory import TeammateStatus, Trajectory, AgentExperience
from mlagents.trainers.policy import Policy
from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
from mlagents.trainers.torch.action_log_probs import LogProbsTuple

# last_group_obs is used to collect the last seen obs of all the agents in the same group,
# and assemble the collab_obs.
self.last_group_obs: Dict[str, Dict[str, List[np.ndarray]]] = defaultdict(
    lambda: defaultdict(list)
)
# current_group_rewards is used to collect the last seen rewards of all the agents in the same group.
self.current_group_rewards: Dict[str, Dict[str, float]] = defaultdict(
    lambda: defaultdict(float)
)
# teammate_status collects the last seen TeammateStatus (obs, reward and action) of all the
# agents in the same group.
self.teammate_status: Dict[str, Dict[str, TeammateStatus]] = defaultdict(
    lambda: defaultdict(None)
)
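# All three dictionaries are keyed first by team_manager_id and then by the agent's global id,
# e.g. self.teammate_status[step.team_manager_id][global_id].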
# last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
# grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
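# Concretely, when the step carrying s_(t+1) arrives for an agent:
#   self.last_take_action_outputs[global_id]          -> a_t     (the action paired with s_t)
#   self.policy.retrieve_previous_action([global_id]) -> a_(t-1) (stored as prev_action)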

def _process_step(
    self, step: Union[TerminalStep, DecisionStep], global_id: str
) -> None:
    stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
    stored_take_action_outputs = self.last_take_action_outputs.get(global_id, None)
    if stored_decision_step is not None and stored_take_action_outputs is not None:
        # Record this agent's latest obs, reward and action in the group bookkeeping.
        self.last_group_obs[step.team_manager_id][
            global_id
        ] = stored_decision_step.obs
        stored_actions = stored_take_action_outputs["action"]
        action_tuple = ActionTuple(
            continuous=stored_actions.continuous[idx],
            discrete=stored_actions.discrete[idx],
        )
        teammate_status = TeammateStatus(
            obs=stored_decision_step.obs,
            reward=step.reward,
            action=action_tuple,
        )
        self.teammate_status[step.team_manager_id][global_id] = teammate_status
        self.current_group_rewards[step.team_manager_id][
            global_id
        ] = step.reward
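        # The TeammateStatus recorded here is read back below when teammate_statuses is
        # assembled for the other agents in the same group.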

    # When this agent's episode ends, its entries are removed from the group bookkeeping dictionaries.
    self._delete_in_nested_dict(self.last_group_obs, global_id)
    self._delete_in_nested_dict(self.current_group_rewards, global_id)
    self._delete_in_nested_dict(self.teammate_status, global_id)

def _delete_in_nested_dict(self, nested_dict, key):
    # Delete the entry for `key` (an agent's global id) from every team manager's inner dict.
    for _manager_id, _team_group in nested_dict.items():
        if key in _team_group:
            del _team_group[key]

prev_action = self.policy.retrieve_previous_action([global_id])[0, :]

# Assemble the teammate observations, statuses and rewards.
# If none were saved, these will be empty lists.
collab_obs = []
for _id, _obs in self.last_group_obs[step.team_manager_id].items():
    if _id != global_id:
        collab_obs.append(_obs)
teammate_statuses = []
for _id, _status in self.teammate_status[step.team_manager_id].items():
    if _id != global_id:
        teammate_statuses.append(_status)
teammate_rewards = []
for _id, _rew in self.current_group_rewards[step.team_manager_id].items():
    if _id != global_id:
        teammate_rewards.append(_rew)

experience = AgentExperience(
    collab_obs=collab_obs,
    teammate_status=teammate_statuses,
    team_rewards=teammate_rewards,
    done=done,
    action=action_tuple,
    action_probs=log_probs_tuple,