from typing import Any, Dict, TypeVar, Union

from mlagents_envs.base_env import (
    ActionTuple,
    DecisionStep,
    DecisionSteps,
    TerminalStep,
    TerminalSteps,
)
from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.trajectory import AgentExperience, AgentStatus, Trajectory
from mlagents.trainers.behavior_id_utils import get_global_agent_id, get_global_group_id

T = TypeVar("T") |
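
# Note on the ID helpers imported above (illustration only; the exact string
# format is an assumption, not guaranteed by this excerpt): get_global_agent_id
# and get_global_group_id fold the environment worker index into the per-env
# agent and group ids, e.g. roughly
#   get_global_agent_id(worker_id=2, agent_id=7)  ->  "2-7"
#   get_global_group_id(worker_id=2, group_id=1)  ->  "group-2-1"
# so that agents and groups coming from different workers never collide in the
# dictionaries keyed by these global ids.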


class AgentProcessor:
    ...

    def add_experiences(
        self,
        decision_steps: DecisionSteps,
        terminal_steps: TerminalSteps,
        worker_id: int,
        previous_action: ActionInfo,
    ) -> None:
        # Make unique agent ids that are global across workers
        action_global_agent_ids = [
            get_global_agent_id(worker_id, ag_id)
            for ag_id in previous_action.agent_ids
        ]

        # Iterate over all the terminal steps, first gather all the teammate obs
        # and then create the AgentExperiences/Trajectories
        for terminal_step in terminal_steps.values():
            self._gather_group_obs(terminal_step, worker_id)
        for terminal_step in terminal_steps.values():
            local_id = terminal_step.agent_id
            global_id = get_global_agent_id(worker_id, local_id)
            self._process_step(
                terminal_step, worker_id, terminal_steps.agent_id_to_index[local_id]
            )
            # Clear the last seen group obs when agents die.
            self._clear_group_obs(global_id)

        # Iterate over all the decision steps, first gather all the teammate obs
        # and then create the trajectories
        for ongoing_step in decision_steps.values():
            self._gather_group_obs(ongoing_step, worker_id)
        for ongoing_step in decision_steps.values():
            local_id = ongoing_step.agent_id
            self._process_step(
                ongoing_step, worker_id, decision_steps.agent_id_to_index[local_id]
            )

        for _gid in action_global_agent_ids:
            ...
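
    # For orientation, a sketch of the group bookkeeping used by the helpers
    # below (the defaultdict wiring is an assumption inferred from how the
    # dicts are indexed; only the key structure is evident from this excerpt):
    #
    #   self.group_status: Dict[str, Dict[str, AgentStatus]] = defaultdict(dict)
    #       # global_group_id -> global_agent_id -> last AgentStatus
    #   self.current_group_obs: Dict[str, Dict[str, Any]] = defaultdict(dict)
    #       # global_group_id -> global_agent_id -> latest observations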
    def _gather_group_obs(
        self, step: Union[TerminalStep, DecisionStep], worker_id: int
    ) -> None:
        global_agent_id = get_global_agent_id(worker_id, step.agent_id)
        stored_decision_step, idx = self.last_step_result.get(
            global_agent_id, (None, None)
        )
        stored_take_action_outputs = self.last_take_action_outputs.get(
            global_agent_id, None
        )
        global_group_id = get_global_group_id(worker_id, step.group_id)
        if stored_decision_step is not None and stored_take_action_outputs is not None:
            # Reconstruct the action this agent took on its last decision step
            stored_actions = stored_take_action_outputs["action"]
            action_tuple = ActionTuple(
                continuous=stored_actions.continuous[idx],
                discrete=stored_actions.discrete[idx],
            )
            group_status = AgentStatus(
                obs=stored_decision_step.obs,
                reward=step.reward,
                action=action_tuple,
                done=isinstance(step, TerminalStep),
            )
            self.group_status[global_group_id][global_agent_id] = group_status
            self.current_group_obs[global_group_id][global_agent_id] = step.obs

    def _delete_in_nested_dict(self, nested_dict: Dict[str, Any], key: str) -> None:
        for _manager_id in list(nested_dict.keys()):
            _team_group = nested_dict[_manager_id]
            self._safe_delete(_team_group, key)
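
    # _safe_delete is not shown in this excerpt; a minimal sketch of what it
    # presumably does (an assumption, mirroring how it is called above):
    def _safe_delete(self, my_dictionary: Dict[Any, Any], key: Any) -> None:
        """Remove a key from a dictionary if present; ignore it otherwise."""
        try:
            del my_dictionary[key]
        except KeyError:
            pass
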
    def _process_step(
        self, step: Union[TerminalStep, DecisionStep], worker_id: int, index: int
    ) -> None:
        terminated = isinstance(step, TerminalStep)
        global_agent_id = get_global_agent_id(worker_id, step.agent_id)
        global_group_id = get_global_group_id(worker_id, step.group_id)
        stored_decision_step, idx = self.last_step_result.get(
            global_agent_id, (None, None)
        )
        stored_take_action_outputs = self.last_take_action_outputs.get(
            global_agent_id, None
        )
        if not terminated:
            # Index is needed to grab from last_take_action_outputs
            self.last_step_result[global_agent_id] = (step, index)

        # This state is the consequence of a past action
        if stored_decision_step is not None and stored_take_action_outputs is not None:
            obs = stored_decision_step.obs
            if self.policy.use_recurrent:
                memory = self.policy.retrieve_memories([global_agent_id])[0, :]
            else:
                memory = None
            done = terminated  # Since this is an ongoing step
            interrupted = step.interrupted if terminated else False
            # Add the outputs of the last eval
            stored_actions = stored_take_action_outputs["action"]
            action_tuple = ActionTuple(
                continuous=stored_actions.continuous[idx],
                discrete=stored_actions.discrete[idx],
            )
            stored_action_probs = stored_take_action_outputs["log_probs"]
            log_probs_tuple = LogProbsTuple(
                continuous=stored_action_probs.continuous[idx],
                discrete=stored_action_probs.discrete[idx],
            )
            action_mask = stored_decision_step.action_mask
            prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :]

            # Gather the last statuses of this agent's teammates
            group_statuses = []
            for _id, _obs in self.group_status[global_group_id].items():
                if _id != global_agent_id:
                    group_statuses.append(_obs)

            experience = AgentExperience(
                obs=obs,
                reward=step.reward,
                done=done,
                action=action_tuple,
                action_probs=log_probs_tuple,
                action_mask=action_mask,
                prev_action=prev_action,
                interrupted=interrupted,
                memory=memory,
                group_status=group_statuses,
                group_reward=step.group_reward,
            )
            # Add the value outputs if needed
            self.experience_buffers[global_agent_id].append(experience)
            self.episode_rewards[global_agent_id] += step.reward
            if not terminated:
                self.episode_steps[global_agent_id] += 1

            # If the episode is over or the trajectory has hit its maximum
            # length, cut it and hand it off
            if (
                len(self.experience_buffers[global_agent_id])
                >= self.max_trajectory_length
                or terminated
            ):
                next_obs = step.obs
                # Gather the latest obs of this agent's teammates
                next_group_obs = []
                for _id, _exp in self.current_group_obs[global_group_id].items():
                    if _id != global_agent_id:
                        next_group_obs.append(_exp)
                trajectory = Trajectory(
                    steps=self.experience_buffers[global_agent_id],
                    agent_id=global_agent_id,
                    next_obs=next_obs,
                    next_group_obs=next_group_obs,
                    behavior_id=self.behavior_id,
                )
                for traj_queue in self.trajectory_queues:
                    traj_queue.put(trajectory)
                self.experience_buffers[global_agent_id] = []
            if terminated:
                # Record episode length.
                self.stats_reporter.add_stat(
                    "Environment/Episode Length",
                    self.episode_steps.get(global_agent_id, 0),
                )
                self._clean_agent_data(global_agent_id)

    def _clean_agent_data(self, global_id: str) -> None:
        """
        Removes the data for an Agent.
        """