浏览代码

top-level timers to see where time is going

/bug-failed-api-check
Chris Elion 5 年前
当前提交
0d65c600
共有 2 个文件被更改,包括 77 次插入68 次删除
  1. 142
      ml-agents/mlagents/trainers/learn.py
  2. 3
      ml-agents/mlagents/trainers/trainer_controller.py

142
ml-agents/mlagents/trainers/learn.py


from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents_envs.exception import UnityEnvironmentException
from mlagents_envs.timers import hierarchical_timer
from mlagents.logging_util import create_logger

:param run_seed: Random seed used for training.
:param run_options: Command line arguments for training.
"""
# Recognize and use docker volume if one is passed as an argument
if not options.docker_target_name:
model_path = f"./models/{options.run_id}"
summaries_dir = "./summaries"
else:
model_path = f"/{options.docker_target_name}/models/{options.run_id}"
summaries_dir = f"/{options.docker_target_name}/summaries"
port = options.base_port
with hierarchical_timer("run_training.setup"):
# Recognize and use docker volume if one is passed as an argument
if not options.docker_target_name:
model_path = f"./models/{options.run_id}"
summaries_dir = "./summaries"
else:
model_path = f"/{options.docker_target_name}/models/{options.run_id}"
summaries_dir = f"/{options.docker_target_name}/summaries"
port = options.base_port
# Configure CSV, Tensorboard Writers and StatsReporter
# We assume reward and episode length are needed in the CSV.
csv_writer = CSVWriter(
summaries_dir,
required_fields=["Environment/Cumulative Reward", "Environment/Episode Length"],
)
tb_writer = TensorboardWriter(summaries_dir)
gauge_write = GaugeWriter()
StatsReporter.add_writer(tb_writer)
StatsReporter.add_writer(csv_writer)
StatsReporter.add_writer(gauge_write)
# Configure CSV, Tensorboard Writers and StatsReporter
# We assume reward and episode length are needed in the CSV.
csv_writer = CSVWriter(
summaries_dir,
required_fields=[
"Environment/Cumulative Reward",
"Environment/Episode Length",
],
)
tb_writer = TensorboardWriter(summaries_dir)
gauge_write = GaugeWriter()
StatsReporter.add_writer(tb_writer)
StatsReporter.add_writer(csv_writer)
StatsReporter.add_writer(gauge_write)
if options.env_path is None:
port = UnityEnvironment.DEFAULT_EDITOR_PORT
env_factory = create_environment_factory(
options.env_path,
options.docker_target_name,
options.no_graphics,
run_seed,
port,
options.env_args,
)
engine_config = EngineConfig(
options.width,
options.height,
options.quality_level,
options.time_scale,
options.target_frame_rate,
)
env_manager = SubprocessEnvManager(env_factory, engine_config, options.num_envs)
maybe_meta_curriculum = try_create_meta_curriculum(
options.curriculum_config, env_manager, options.lesson
)
sampler_manager, resampling_interval = create_sampler_manager(
options.sampler_config, run_seed
)
trainer_factory = TrainerFactory(
options.trainer_config,
summaries_dir,
options.run_id,
model_path,
options.keep_checkpoints,
options.train_model,
options.load_model,
run_seed,
maybe_meta_curriculum,
options.multi_gpu,
)
# Create controller and begin training.
tc = TrainerController(
trainer_factory,
model_path,
summaries_dir,
options.run_id,
options.save_freq,
maybe_meta_curriculum,
options.train_model,
run_seed,
sampler_manager,
resampling_interval,
)
if options.env_path is None:
port = UnityEnvironment.DEFAULT_EDITOR_PORT
env_factory = create_environment_factory(
options.env_path,
options.docker_target_name,
options.no_graphics,
run_seed,
port,
options.env_args,
)
engine_config = EngineConfig(
options.width,
options.height,
options.quality_level,
options.time_scale,
options.target_frame_rate,
)
env_manager = SubprocessEnvManager(env_factory, engine_config, options.num_envs)
maybe_meta_curriculum = try_create_meta_curriculum(
options.curriculum_config, env_manager, options.lesson
)
sampler_manager, resampling_interval = create_sampler_manager(
options.sampler_config, run_seed
)
trainer_factory = TrainerFactory(
options.trainer_config,
summaries_dir,
options.run_id,
model_path,
options.keep_checkpoints,
options.train_model,
options.load_model,
run_seed,
maybe_meta_curriculum,
options.multi_gpu,
)
# Create controller and begin training.
tc = TrainerController(
trainer_factory,
model_path,
summaries_dir,
options.run_id,
options.save_freq,
maybe_meta_curriculum,
options.train_model,
run_seed,
sampler_manager,
resampling_interval,
)
# Begin training
try:
tc.start_learning(env_manager)

3
ml-agents/mlagents/trainers/trainer_controller.py


brain_names_to_measure_vals[brain_name] = measure_val
return brain_names_to_measure_vals
@timed
def _save_model(self):
"""
Saves current model to checkpoint folder.

"permissions are set correctly.".format(model_path)
)
@timed
def _reset_env(self, env: EnvManager) -> None:
"""Resets the environment.

for behavior_id in behavior_ids:
self._create_trainer_and_manager(env_manager, behavior_id)
@timed
def start_learning(self, env_manager: EnvManager) -> None:
self._create_model_path(self.model_path)
tf.reset_default_graph()

正在加载...
取消
保存