浏览代码

fix grouping for int id

/develop/teammanager/int
Ruo-Ping Dong 4 年前
当前提交
fb4a3bd2
共有 3 个文件被更改,包括 47 次插入14 次删除
  1. 1
      ml-agents-envs/mlagents_envs/base_env.py
  2. 53
      ml-agents/mlagents/trainers/agent_processor.py
  3. 7
      ml-agents/mlagents/trainers/behavior_id_utils.py

1
ml-agents-envs/mlagents_envs/base_env.py


self.agent_id: np.ndarray = agent_id
self.team_manager_id: np.ndarray = team_manager_id
self._agent_id_to_index: Optional[Dict[AgentId, int]] = None
self.team_manager_id: Optional[List[str]] = team_manager_id
@property
def agent_id_to_index(self) -> Dict[AgentId, int]:

53
ml-agents/mlagents/trainers/agent_processor.py


from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.behavior_id_utils import get_global_agent_id
from mlagents.trainers.behavior_id_utils import (
get_global_agent_id,
get_global_manager_id,
)
T = TypeVar("T")

for terminal_step in terminal_steps.values():
local_id = terminal_step.agent_id
global_id = get_global_agent_id(worker_id, local_id)
self._gather_teammate_obs(terminal_step, global_id)
self._gather_teammate_obs(terminal_step, global_id, worker_id)
terminal_step, global_id, terminal_steps.agent_id_to_index[local_id]
terminal_step,
global_id,
worker_id,
terminal_steps.agent_id_to_index[local_id],
)
# Clear the last seen group obs when agents die.
self._clear_teammate_obs(global_id)

for ongoing_step in decision_steps.values():
local_id = ongoing_step.agent_id
global_id = get_global_agent_id(worker_id, local_id)
self._gather_teammate_obs(ongoing_step, global_id)
self._gather_teammate_obs(ongoing_step, global_id, worker_id)
ongoing_step, global_id, decision_steps.agent_id_to_index[local_id]
ongoing_step,
global_id,
worker_id,
decision_steps.agent_id_to_index[local_id],
)
for _gid in action_global_agent_ids:

)
def _gather_teammate_obs(
self, step: Union[TerminalStep, DecisionStep], global_id: str
self, step: Union[TerminalStep, DecisionStep], global_id: str, worker_id: int
self.last_group_obs[step.team_manager_id][
global_manager_id = get_global_manager_id(
worker_id, step.team_manager_id
)
self.last_group_obs[global_manager_id][
self.current_group_obs[step.team_manager_id][global_id] = step.obs
self.current_group_obs[global_manager_id][global_id] = step.obs
to_delete = []
self._safe_delete(_team_group, _manager_id)
to_delete.append(_manager_id)
for _manager_id in to_delete:
self._safe_delete(self.current_group_obs, _manager_id)
to_delete = []
self._safe_delete(_team_group, _manager_id)
to_delete.append(_manager_id)
for _manager_id in to_delete:
self._safe_delete(self.last_group_obs, _manager_id)
self, step: Union[TerminalStep, DecisionStep], global_id: str, index: int
self,
step: Union[TerminalStep, DecisionStep],
global_id: str,
worker_id: int,
index: int,
) -> None:
terminated = isinstance(step, TerminalStep)
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))

# Assemble teammate_obs. If none saved, then it will be an empty list.
collab_obs = []
for _id, _obs in self.last_group_obs[step.team_manager_id].items():
global_manager_id = get_global_manager_id(worker_id, step.team_manager_id)
for _id, _obs in self.last_group_obs[global_manager_id].items():
if _id != global_id:
collab_obs.append(_obs)

):
next_obs = step.obs
next_collab_obs = []
for _id, _exp in self.current_group_obs[step.team_manager_id].items():
global_manager_id = get_global_manager_id(
worker_id, step.team_manager_id
)
for _id, _exp in self.current_group_obs[global_manager_id].items():
if _id != global_id:
next_collab_obs.append(_exp)

7
ml-agents/mlagents/trainers/behavior_id_utils.py


Create an agent id that is unique across environment workers using the worker_id.
"""
return f"${worker_id}-{agent_id}"
def get_global_manager_id(worker_id: int, manager_id: int) -> str:
"""
Create an agent id that is unique across environment workers using the worker_id.
"""
return f"#{worker_id}-{manager_id}"
正在加载...
取消
保存