浏览代码

Add cc to ghost trainer

/comms-grad
Ervin Teng 4 年前
当前提交
7087b7b3
共有 2 个文件被更改,包括 26 次插入10 次删除
  1. 23
      ml-agents/mlagents/trainers/behavior_id_utils.py
  2. 13
      ml-agents/mlagents/trainers/ghost/trainer.py

23
ml-agents/mlagents/trainers/behavior_id_utils.py


from typing import NamedTuple
from typing import NamedTuple, Optional
from urllib.parse import urlparse, parse_qs

)
def create_name_behavior_id(name: str, team_id: int) -> str:
def create_name_behavior_id(
name: str, team_id: Optional[int] = None, group_id: Optional[int] = None
) -> str:
"""
Reconstructs fully qualified behavior name from name and team_id
:param name: brain name
:param team_id: team ID
:return: name_behavior_id
Reconstructs fully qualified behavior name from name and team_id
:param name: brain name
:param team_id: team ID
:return: name_behavior_id
"""
return name + "?team=" + str(team_id)
final_name = name
if team_id is not None:
final_name += f"?team={team_id}"
if group_id is not None:
final_name += f"&group={group_id}"
return final_name
def get_global_agent_id(worker_id: int, agent_id: int) -> str:

13
ml-agents/mlagents/trainers/ghost/trainer.py


:param parsed_behavior_id: Behavior ID that the policy should belong to.
:param policy: Policy to associate with name_behavior_id.
"""
name_behavior_id = parsed_behavior_id.behavior_id
self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id
name_behavior_id = create_name_behavior_id(
parsed_behavior_id.brain_name, team_id=parsed_behavior_id.team_id
)
# Add policy only based on the team id, not the group id
self._name_to_parsed_behavior_id[
parsed_behavior_id.behavior_id
] = parsed_behavior_id
self.policies[name_behavior_id] = policy
def get_policy(self, name_behavior_id: str) -> Policy:

:return: Policy associated with name_behavior_id
"""
parsed_behavior_id = BehaviorIdentifiers.from_name_behavior_id(name_behavior_id)
name_behavior_id = create_name_behavior_id(
parsed_behavior_id.brain_name, team_id=parsed_behavior_id.team_id
)
return self.policies[name_behavior_id]
def _save_snapshot(self) -> None:

正在加载...
取消
保存