浏览代码

Ability to disable threading

/develop/sac-apex
Ervin Teng 5 年前
当前提交
ed06f37c
共有 4 个文件被更改,包括 35 次插入14 次删除
  1. 4
      ml-agents/mlagents/trainers/ghost/trainer.py
  2. 20
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  3. 4
      ml-agents/mlagents/trainers/trainer/trainer.py
  4. 21
      ml-agents/mlagents/trainers/trainer_controller.py

4
ml-agents/mlagents/trainers/ghost/trainer.py


self.change_current_elo(change)
self._stats_reporter.add_stat("Self-play/ELO", self.current_elo)
def advance(self) -> None:
def advance(self, empty_queue: bool = False) -> None:
"""
Steps the trainer, passing trajectories to wrapped trainer and calling trainer advance
"""

# adds to wrapped trainers queue
internal_trajectory_queue.put(t)
self._process_trajectory(t)
if not empty_queue:
break
except AgentManagerQueue.Empty:
pass
else:

20
ml-agents/mlagents/trainers/trainer/rl_trainer.py


if step_after_process >= self.next_summary_step and self.get_step != 0:
self._write_summary(self.next_summary_step)
def advance(self) -> None:
def advance(self, empty_queue: bool = False) -> None:
:param empty_queue: Whether or not to empty the queue when called. For synchronous
operation, we need to do so to avoid the queue filling up.
try:
t = traj_queue.get(0.05)
self._process_trajectory(t)
except AgentManagerQueue.Empty:
break
# We grab at most the maximum length of the queue.
# This ensures that even if the queue is being filled faster than it is
# being emptied, the trajectories in the queue are on-policy.
for _ in range(traj_queue.maxlen):
try:
t = traj_queue.get(0.05)
self._process_trajectory(t)
if not empty_queue:
break
except AgentManagerQueue.Empty:
break
if self.should_still_train:
if self._is_ready_update():
with hierarchical_timer("_update_policy"):

4
ml-agents/mlagents/trainers/trainer/trainer.py


pass
@abc.abstractmethod
def advance(self) -> None:
def advance(self, empty_queue: bool = False) -> None:
:param empty_queue: Whether or not to empty the queue when called. For synchronous
operation, we need to do so to avoid the queue filling up.
"""
pass

21
ml-agents/mlagents/trainers/trainer_controller.py


training_seed: int,
sampler_manager: SamplerManager,
resampling_interval: Optional[int],
threaded: bool = True,
):
"""
:param model_path: Path to save the model.

:param training_seed: Seed to use for Numpy and Tensorflow random number generation.
:param sampler_manager: SamplerManager object handles samplers for resampling the reset parameters.
:param resampling_interval: Specifies number of simulation steps after which reset parameters are resampled.
:param threaded: Whether or not to run trainers in a separate thread. Disable for testing/debugging.
"""
self.trainers: Dict[str, Trainer] = {}
self.brain_name_to_identifier: Dict[str, Set] = defaultdict(set)

self.sampler_manager = sampler_manager
self.resampling_interval = resampling_interval
self.threaded = threaded
self.trainer_threads: List[threading.Thread] = []
self.kill_trainers = False
np.random.seed(training_seed)

trainer.publish_policy_queue(agent_manager.policy_queue)
trainer.subscribe_trajectory_queue(agent_manager.trajectory_queue)
# Start trainer thread
trainerthread = threading.Thread(
target=self.trainer_update_func, args=(trainer,), daemon=True
)
trainerthread.start()
self.trainer_threads.append(trainerthread)
if self.threaded:
# Start trainer thread
trainerthread = threading.Thread(
target=self.trainer_update_func, args=(trainer,), daemon=True
)
trainerthread.start()
self.trainer_threads.append(trainerthread)
def _create_trainers_and_managers(
self, env_manager: EnvManager, behavior_ids: Set[str]

self.trainers[brain_name].stats_reporter.set_stat(
"Environment/Lesson", curr.lesson_num
)
if not self.threaded:
with hierarchical_timer("trainer_advance"):
for trainer in self.trainers.values():
trainer.advance(empty_queue=True)
return num_steps

正在加载...
取消
保存