浏览代码

Make trainer in separate threads

/develop/sac-apex
Ervin Teng 5 年前
当前提交
3deb8e30
共有 3 个文件被更改,包括 44 次插入18 次删除
  1. 24
      ml-agents/mlagents/trainers/agent_processor.py
  2. 12
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  3. 26
      ml-agents/mlagents/trainers/trainer_controller.py

24
ml-agents/mlagents/trainers/agent_processor.py


import sys
from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Set
from collections import defaultdict, Counter, deque
from typing import List, Dict, TypeVar, Generic, Tuple, Set
from collections import defaultdict, Counter
import queue
from mlagents_envs.base_env import BatchedStepResult, StepResult
from mlagents.trainers.trajectory import Trajectory, AgentExperience

separately from an AgentManager.
"""
self.maxlen: int = maxlen
self.queue: Deque[T] = deque(maxlen=self.maxlen)
self.queue: queue.Queue = queue.Queue(maxsize=maxlen)
return len(self.queue) == 0
return self.queue.empty()
return self.queue.popleft()
except IndexError:
return self.queue.get_nowait()
except queue.Empty:
raise self.Empty("The AgentManagerQueue is empty.")
def get(self, timeout: float) -> T:
"""
Blocking get
"""
try:
return self.queue.get(timeout=timeout)
except queue.Empty:
self.queue.append(item)
self.queue.put(item)
class AgentManager(AgentProcessor):

12
ml-agents/mlagents/trainers/trainer/rl_trainer.py


def advance(self) -> None:
"""
Steps the trainer, taking in trajectories and updates if ready.
Will block and wait if there are no trajectories.
"""
with hierarchical_timer("process_trajectory"):
for traj_queue in self.trajectory_queues:

for _ in range(traj_queue.maxlen):
try:
t = traj_queue.get_nowait()
self._process_trajectory(t)
except AgentManagerQueue.Empty:
break
try:
t = traj_queue.get(0.05)
self._process_trajectory(t)
except AgentManagerQueue.Empty:
break
if self.should_still_train:
if self._is_ready_update():
with hierarchical_timer("_update_policy"):

26
ml-agents/mlagents/trainers/trainer_controller.py


import os
import sys
import logging
from typing import Dict, Optional, Set
import threading
from typing import Dict, Optional, Set, List
from collections import defaultdict
import numpy as np

self.meta_curriculum = meta_curriculum
self.sampler_manager = sampler_manager
self.resampling_interval = resampling_interval
self.trainer_threads: List[threading.Thread] = []
self.kill_trainers = False
np.random.seed(training_seed)
tf.set_random_seed(training_seed)

trainer.publish_policy_queue(agent_manager.policy_queue)
trainer.subscribe_trajectory_queue(agent_manager.trajectory_queue)
# Start trainer thread
trainerthread = threading.Thread(
target=self.trainer_update_func, args=(trainer,), daemon=True
)
trainerthread.start()
self.trainer_threads.append(trainerthread)
def _create_trainers_and_managers(
self, env_manager: EnvManager, behavior_ids: Set[str]

if global_step != 0 and self.train_model:
self._save_model()
except (KeyboardInterrupt, UnityCommunicationException):
self.kill_trainers = True
if self.train_model:
self._save_model_when_interrupted()
pass

"Environment/Lesson", curr.lesson_num
)
# Advance trainers. This can be done in a separate loop in the future.
with hierarchical_timer("trainer_advance"):
for trainer in self.trainers.values():
trainer.advance()
# # Advance trainers. This can be done in a separate loop in the future.
# with hierarchical_timer("trainer_advance"):
# for trainer in self.trainers.values():
# trainer.advance()
def trainer_update_func(self, trainer: Trainer) -> None:
while not self.kill_trainers:
with hierarchical_timer("trainer_advance"):
trainer.advance()
正在加载...
取消
保存