
added team_change as a yaml config

/asymm-envs
Andrew Cohen 4 years ago
Current commit
19552661
2 files changed, 25 insertions and 20 deletions
  1. ml-agents/mlagents/trainers/ghost/controller.py (35 lines changed)
  2. ml-agents/mlagents/trainers/ghost/trainer.py (10 lines changed)
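For context, the new team_change key joins the existing self-play settings. Below is a minimal sketch of the parsed self_play section as a Python dict, assuming only the key names and default values visible in the diff (save_steps, swap_steps, team_change, initial_elo); the actual YAML layout is not shown in this commit.

self_play_parameters = {
    "save_steps": 20000,     # steps between saving policy snapshots
    "swap_steps": 20000,     # steps between swapping opponent snapshots
    "team_change": 100000,   # new: trainer steps each team trains before the learning team changes
    "initial_elo": 1200.0,   # starting Elo rating for each policy
}

# The trainer falls back to the same defaults when a key is omitted:
steps_to_train_team = self_play_parameters.get("team_change", 100000)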

ml-agents/mlagents/trainers/ghost/controller.py (35 lines changed)


which corresponds to the number of trainer steps between changing learning teams.
"""
def __init__(self, swap_interval: int, maxlen: int = 10):
def __init__(self, maxlen: int = 10):
:param swap_interval: Number of trainer steps between changing learning teams.
self._swap_interval = swap_interval
self._last_swap: Dict[int, int] = {}
# Dict from team id to GhostTrainer
# Dict from team id to GhostTrainer for ELO calculation
self._ghost_trainers: Dict[int, GhostTrainer] = {}
def subscribe_team_id(self, team_id: int, trainer: GhostTrainer) -> None:

"""
if team_id not in self._ghost_trainers:
self._ghost_trainers[team_id] = trainer
self._last_swap[team_id] = 0
def get_learning_team(self, step: int) -> int:
def get_learning_team(self) -> int:
Returns the current learning team. If 'swap_interval' steps have elapsed, the current
learning team is added to the end of the queue and then updated with the next in line.
:param step: Current step of the trainer.
Returns the current learning team.
if step >= self._swap_interval + self._last_swap[self._learning_team]:
self._last_swap[self._learning_team] = step
self._queue.append(self._learning_team)
self._learning_team = self._queue.popleft()
logger.debug(
"Learning team {} swapped on step {}".format(
self._learning_team, self._last_swap
)
)
def finish_training(self, step: int) -> None:
"""
The current learning team is added to the end of the queue and then updated with the
next in line.
:param step: The step of the trainer for debugging
"""
self._queue.append(self._learning_team)
self._learning_team = self._queue.popleft()
logger.debug(
"Learning team {} swapped on step {}".format(self._learning_team, step)
)
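The rotation that the new finish_training method performs is just a deque cycle; here is a minimal standalone sketch with plain team ids standing in for subscribed GhostTrainers:

from collections import deque

# Teams 0, 1 and 2, with team 0 currently learning and the rest waiting in the queue.
learning_team = 0
queue = deque([1, 2], maxlen=10)

for _ in range(3):
    # What finish_training does: current learner goes to the back, next in line takes over.
    queue.append(learning_team)
    learning_team = queue.popleft()
    print(learning_team)  # prints 1, then 2, then 0 (cycles through all teams)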
# Adapted from https://github.com/Unity-Technologies/ml-agents/pull/1975 and
# https://metinmediamath.wordpress.com/2013/11/27/how-to-calculate-the-elo-rating-including-example/
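The Elo bookkeeping mentioned above follows the standard rating formula described in the linked post; the sketch below is that textbook update with an illustrative K-factor of 16, not necessarily the exact in-repo calculation.

def elo_change(rating: float, opponent_rating: float, result: float, k: float = 16.0) -> float:
    """result is 1.0 for a win, 0.5 for a draw and 0.0 for a loss."""
    expected = 1.0 / (1.0 + 10.0 ** ((opponent_rating - rating) / 400.0))
    return k * (result - expected)

# Two policies start at the chess-style initial rating of 1200.
delta = elo_change(1200.0, 1200.0, result=1.0)
print(1200.0 + delta, 1200.0 - delta)  # 1208.0 1192.0 (the update is zero-sum)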

ml-agents/mlagents/trainers/ghost/trainer.py (10 lines changed)


)
self.steps_between_save = self_play_parameters.get("save_steps", 20000)
self.steps_between_swap = self_play_parameters.get("swap_steps", 20000)
self.steps_to_train_team = self_play_parameters.get("team_change", 100000)
# Counts the number of steps of the ghost policies. Snapshot swapping
# depends on this counter whereas snapshot saving and team switching depend
# on the steps of the wrapped trainer. This ensures that all teams train for the
# same number of trainer steps.

self.wrapped_trainer_team: int = None
self.last_save: int = 0
self.last_swap: int = 0
self.last_team_change: int = 0
# Chosen because it is the initial ELO in Chess
self.initial_elo: float = self_play_parameters.get("initial_elo", 1200.0)

self.next_summary_step = self.trainer.next_summary_step
self.trainer.advance()
if self.get_step - self.last_team_change > self.steps_to_train_team:
self.controller.finish_training(self.get_step)
self.last_team_change = self.get_step
next_learning_team = self.controller.get_learning_team()
# CASE 1: Current learning team is managed by this GhostTrainer.
# If the learning team changes, the following loop over queues will push the
# new policy into the policy queue for the new learning agent if

# pushing fixed snapshots
# Case 3: No team change. The if statement just continues to push the policy
# into the correct queue (or not if not learning team).
next_learning_team = self.controller.get_learning_team(self.get_step)
for brain_name in self._internal_policy_queues:
internal_policy_queue = self._internal_policy_queues[brain_name]
try:

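Taken together, the trainer-side change makes advance() hand the learning role back to the controller once team_change trainer steps have elapsed. A minimal sketch of that loop, using a hypothetical SimpleController in place of the real GhostController and an integer counter in place of get_step:

from collections import deque

class SimpleController:
    """Illustrative stand-in for the GhostController queue behaviour."""
    def __init__(self) -> None:
        self._learning_team = 0
        self._queue = deque([1])

    def get_learning_team(self) -> int:
        return self._learning_team

    def finish_training(self, step: int) -> None:
        self._queue.append(self._learning_team)
        self._learning_team = self._queue.popleft()

controller = SimpleController()
steps_to_train_team = 100000  # the new "team_change" setting
last_team_change = 0

for step in range(0, 400001, 50000):  # pretend trainer steps
    if step - last_team_change > steps_to_train_team:
        controller.finish_training(step)
        last_team_change = step
    # prints team 0, swaps to team 1 at step 150000, back to team 0 at step 300000
    print(step, controller.get_learning_team())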