浏览代码

update agent processor to use group id

/develop/superpush/int
Ruo-Ping Dong 4 年前
当前提交
d7ade5c3
共有 2 个文件被更改,包括 17 次插入 和 24 次删除
  1. 37
      ml-agents/mlagents/trainers/agent_processor.py
  2. 4
      ml-agents/mlagents/trainers/behavior_id_utils.py

37
ml-agents/mlagents/trainers/agent_processor.py


from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.behavior_id_utils import (
get_global_agent_id,
get_global_manager_id,
)
from mlagents.trainers.behavior_id_utils import get_global_agent_id, get_global_group_id
T = TypeVar("T")

) -> None:
stored_decision_step, idx = self.last_step_result.get(global_id, (None, None))
if stored_decision_step is not None:
if step.team_manager_id > 0:
global_manager_id = get_global_manager_id(
worker_id, step.team_manager_id
)
self.last_group_obs[global_manager_id][
if step.group_id > 0:
global_group_id = get_global_group_id(worker_id, step.group_id)
self.last_group_obs[global_group_id][
self.current_group_obs[global_manager_id][global_id] = step.obs
self.current_group_obs[global_group_id][global_id] = step.obs
for _manager_id in list(self.current_group_obs.keys()):
_team_group = self.current_group_obs[_manager_id]
for _group_id in list(self.current_group_obs.keys()):
_team_group = self.current_group_obs[_group_id]
self._safe_delete(self.current_group_obs, _manager_id)
for _manager_id in list(self.last_group_obs.keys()):
_team_group = self.last_group_obs[_manager_id]
self._safe_delete(self.current_group_obs, _group_id)
for _group_id in list(self.last_group_obs.keys()):
_team_group = self.last_group_obs[_group_id]
self._safe_delete(self.last_group_obs, _manager_id)
self._safe_delete(self.last_group_obs, _group_id)
def _process_step(
self,

# Assemble teammate_obs. If none saved, then it will be an empty list.
collab_obs = []
global_manager_id = get_global_manager_id(worker_id, step.team_manager_id)
for _id, _obs in self.last_group_obs[global_manager_id].items():
global_group_id = get_global_group_id(worker_id, step.group_id)
for _id, _obs in self.last_group_obs[global_group_id].items():
if _id != global_id:
collab_obs.append(_obs)

):
next_obs = step.obs
next_collab_obs = []
global_manager_id = get_global_manager_id(
worker_id, step.team_manager_id
)
for _id, _exp in self.current_group_obs[global_manager_id].items():
global_group_id = get_global_group_id(worker_id, step.group_id)
for _id, _exp in self.current_group_obs[global_group_id].items():
if _id != global_id:
next_collab_obs.append(_exp)

4
ml-agents/mlagents/trainers/behavior_id_utils.py


return f"${worker_id}-{agent_id}"
def get_global_manager_id(worker_id: int, manager_id: int) -> str:
def get_global_group_id(worker_id: int, group_id: int) -> str:
return f"${worker_id}-{manager_id}"
return f"${worker_id}-{group_id}"
正在加载...
取消
保存