浏览代码

Fix AgentProcessor for TeamManager

Should work for variable decision frequencies (untested)
/develop/centralizedcritic
Ervin Teng 4 年前
当前提交
a7e368b8
共有 1 个文件被更改,包括 48 次插入32 次删除
  1. 80
      ml-agents/mlagents/trainers/agent_processor.py

80
ml-agents/mlagents/trainers/agent_processor.py


:param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
"""
self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
self.last_experience: Dict[str, AgentExperience] = {}
lambda: defaultdict(list)
)
# last_group_obs is used to collect the last seen obs of all the agents in the same group,
# and assemble the collab_obs.
self.last_group_obs: Dict[str, Dict[str, List[np.ndarray]]] = defaultdict(
lambda: defaultdict(list)
)
# last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while

if global_id in self.last_step_result: # Don't store if agent just reset
self.last_take_action_outputs[global_id] = take_action_outputs
# Iterate over all the terminal steps
# Iterate over all the terminal steps, first gather all the teammate obs
# and then create the AgentExperiences/Trajectories
self._process_step(
terminal_step, global_id, terminal_steps.agent_id_to_index[local_id]
)
self._gather_teammate_obs(terminal_step, global_id)
self._assemble_trajectory(terminal_step, global_id)
self.current_group_obs.clear()
self._process_step(
terminal_step, global_id, terminal_steps.agent_id_to_index[local_id]
)
# Clear the last seen group obs when agents die.
self._clear_teammate_obs(global_id)
self._safe_delete(self.last_experience, global_id)
# Iterate over all the decision steps
# Iterate over all the decision steps, first gather all the teammate obs
# and then create the trajectories
self._process_step(
ongoing_step, global_id, decision_steps.agent_id_to_index[local_id]
)
self._gather_teammate_obs(ongoing_step, global_id)
self._assemble_trajectory(ongoing_step, global_id)
self.current_group_obs.clear()
self._process_step(
ongoing_step, global_id, decision_steps.agent_id_to_index[local_id]
)
for _gid in action_global_agent_ids:
# If the ID doesn't have a last step result, the agent just reset,

[_gid], take_action_outputs["action"]
)
def _gather_teammate_obs(
self, step: Union[TerminalStep, DecisionStep], global_id: str
) -> None:
# Records observations for the agent `global_id` in the per-team tables
# self.last_group_obs and self.current_group_obs, keyed first by
# step.team_manager_id and then by global_id.
# NOTE(review): indentation is not preserved in this view, so the exact
# nesting of the statements below (in particular the final assignment)
# should be confirmed against the real file.
# Fetch the step previously stored for this agent; falls back to
# (None, None) when nothing is stored. `idx` is not used in this block.
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
if stored_decision_step is not None:
if step.team_manager_id is not None:
# Observations taken from the previously stored step.
self.last_group_obs[step.team_manager_id][
global_id
] = stored_decision_step.obs
# Observations taken from the step passed in to this call.
self.current_group_obs[step.team_manager_id][global_id] = step.obs
def _clear_teammate_obs(self, global_id: str) -> None:
for _manager_id, _team_group in self.current_group_obs.items():
self._safe_delete(_team_group, global_id)
if not _team_group: # if dict is empty
self._safe_delete(_team_group, _manager_id)
for _manager_id, _team_group in self.last_group_obs.items():
self._safe_delete(_team_group, global_id)
if not _team_group: # if dict is empty
self._safe_delete(_team_group, _manager_id)
def _process_step(
self, step: Union[TerminalStep, DecisionStep], global_id: str, index: int
) -> None:

)
action_mask = stored_decision_step.action_mask
prev_action = self.policy.retrieve_previous_action([global_id])[0, :]
# Assemble teammate_obs. If none saved, then it will be an empty list.
collab_obs = []
for _id, _obs in self.last_group_obs[step.team_manager_id].items():
if _id != global_id:
collab_obs.append(_obs)
collab_obs=[],
collab_obs=collab_obs,
reward=step.reward,
done=done,
action=action_tuple,

interrupted=interrupted,
memory=memory,
)
if step.team_manager_id is not None:
self.current_group_obs[step.team_manager_id][global_id] += step.obs
self.last_experience[global_id] = experience
def _assemble_trajectory(
self, step: Union[TerminalStep, DecisionStep], global_id: str
) -> None:
if global_id in self.last_experience:
experience = self.last_experience[global_id]
terminated = isinstance(step, TerminalStep)
# Add remaining shared obs to AgentExperience
for _id, _exp in self.last_experience.items():
if _id == global_id:
continue
else:
self.last_experience[global_id].collab_obs.append(_exp.obs)
# Add the value outputs if needed
self.experience_buffers[global_id].append(experience)
self.episode_rewards[global_id] += step.reward

正在加载...
取消
保存