浏览代码

Global group ids

/develop/coma2/samenet
Ervin Teng 4 年前
当前提交
a25bb4d4
共有 2 个文件被更改,包括 52 次插入36 次删除
  1. 79
      ml-agents/mlagents/trainers/agent_processor.py
  2. 9
      ml-agents/mlagents/trainers/behavior_id_utils.py

79
ml-agents/mlagents/trainers/agent_processor.py


from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.behavior_id_utils import get_global_agent_id
from mlagents.trainers.behavior_id_utils import get_global_agent_id, get_global_group_id
T = TypeVar("T")

# Iterate over all the terminal steps, first gather all the teammate obs
# and then create the AgentExperiences/Trajectories
for terminal_step in terminal_steps.values():
local_id = terminal_step.agent_id
global_id = get_global_agent_id(worker_id, local_id)
self._gather_group_obs(terminal_step, global_id)
self._gather_group_obs(terminal_step, worker_id)
terminal_step, global_id, terminal_steps.agent_id_to_index[local_id]
terminal_step, worker_id, terminal_steps.agent_id_to_index[local_id]
)
# Clear the last seen group obs when agents die.
self._clear_group_obs(global_id)

# Iterate over all the decision steps, first gather all the teammate obs
# and then create the trajectories
for ongoing_step in decision_steps.values():
local_id = ongoing_step.agent_id
global_id = get_global_agent_id(worker_id, local_id)
self._gather_group_obs(ongoing_step, global_id)
self._gather_group_obs(ongoing_step, worker_id)
global_id = get_global_agent_id(worker_id, local_id)
ongoing_step, global_id, decision_steps.agent_id_to_index[local_id]
ongoing_step, worker_id, decision_steps.agent_id_to_index[local_id]
)
for _gid in action_global_agent_ids:

)
def _gather_group_obs(
self, step: Union[TerminalStep, DecisionStep], global_id: str
self, step: Union[TerminalStep, DecisionStep], worker_id: int
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
stored_take_action_outputs = self.last_take_action_outputs.get(global_id, None)
global_agent_id = get_global_agent_id(worker_id, step.agent_id)
stored_decision_step, idx = self.last_step_result.get(
global_agent_id, (None, None)
)
stored_take_action_outputs = self.last_take_action_outputs.get(
global_agent_id, None
)
global_group_id = get_global_group_id(worker_id, step.group_id)
stored_actions = stored_take_action_outputs["action"]
action_tuple = ActionTuple(
continuous=stored_actions.continuous[idx],

action=action_tuple,
done=isinstance(step, TerminalStep),
)
self.group_status[step.group_id][global_id] = group_status
self.current_group_obs[step.group_id][global_id] = step.obs
self.group_status[global_group_id][global_agent_id] = group_status
self.current_group_obs[global_group_id][global_agent_id] = step.obs
def _delete_in_nested_dict(self, nested_dict, key):
def _delete_in_nested_dict(self, nested_dict: Dict[str, Any], key: str) -> None:
for _manager_id in list(nested_dict.keys()):
_team_group = nested_dict[_manager_id]
self._safe_delete(_team_group, key)

def _process_step(
self, step: Union[TerminalStep, DecisionStep], global_id: str, index: int
self, step: Union[TerminalStep, DecisionStep], worker_id: int, index: int
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
stored_take_action_outputs = self.last_take_action_outputs.get(global_id, None)
global_agent_id = get_global_agent_id(worker_id, step.agent_id)
global_group_id = get_global_group_id(worker_id, step.group_id)
stored_decision_step, idx = self.last_step_result.get(
global_agent_id, (None, None)
)
stored_take_action_outputs = self.last_take_action_outputs.get(
global_agent_id, None
)
self.last_step_result[global_id] = (step, index)
self.last_step_result[global_agent_id] = (step, index)
memory = self.policy.retrieve_memories([global_id])[0, :]
memory = self.policy.retrieve_memories([global_agent_id])[0, :]
else:
memory = None
done = terminated # Since this is an ongoing step

discrete=stored_action_probs.discrete[idx],
)
action_mask = stored_decision_step.action_mask
prev_action = self.policy.retrieve_previous_action([global_id])[0, :]
prev_action = self.policy.retrieve_previous_action([global_agent_id])[0, :]
for _id, _obs in self.group_status[step.group_id].items():
if _id != global_id:
for _id, _obs in self.group_status[global_group_id].items():
if _id != global_agent_id:
group_statuses.append(_obs)
experience = AgentExperience(

group_reward=step.group_reward,
)
# Add the value outputs if needed
self.experience_buffers[global_id].append(experience)
self.episode_rewards[global_id] += step.reward
self.experience_buffers[global_agent_id].append(experience)
self.episode_rewards[global_agent_id] += step.reward
self.episode_steps[global_id] += 1
self.episode_steps[global_agent_id] += 1
len(self.experience_buffers[global_id]) >= self.max_trajectory_length
len(self.experience_buffers[global_agent_id])
>= self.max_trajectory_length
for _id, _exp in self.current_group_obs[step.group_id].items():
if _id != global_id:
for _id, _exp in self.current_group_obs[global_group_id].items():
if _id != global_agent_id:
steps=self.experience_buffers[global_id],
agent_id=global_id,
steps=self.experience_buffers[global_agent_id],
agent_id=global_agent_id,
next_obs=next_obs,
next_group_obs=next_group_obs,
behavior_id=self.behavior_id,

self.experience_buffers[global_id] = []
self.experience_buffers[global_agent_id] = []
"Environment/Episode Length", self.episode_steps.get(global_id, 0)
"Environment/Episode Length",
self.episode_steps.get(global_agent_id, 0),
self._clean_agent_data(global_id)
self._clean_agent_data(global_agent_id)
def _clean_agent_data(self, global_id: str) -> None:
"""

9
ml-agents/mlagents/trainers/behavior_id_utils.py


"""
Create an agent id that is unique across environment workers using the worker_id.
"""
return f"${worker_id}-{agent_id}"
return f"agent_{worker_id}-{agent_id}"
def get_global_group_id(worker_id: int, group_id: int) -> str:
    """
    Build a group identifier that stays unique across environment workers.

    The worker id is embedded in the returned string, so the same local
    group id coming from two different workers yields two different
    global ids.

    :param worker_id: Id of the environment worker the group belongs to.
    :param group_id: Group id as reported locally by that worker.
    :return: A string of the form ``group_<worker_id>-<group_id>``.
    """
    global_group_id = f"group_{worker_id}-{group_id}"
    return global_group_id
正在加载...
取消
保存