浏览代码

[refactor] Move checkpoint saving into trainer (#4034)

/MLA-1734-demo-provider
GitHub 5 年前
当前提交
09853e13
共有 20 个文件被更改,包括 96 次插入54 次删除
  1. 1
      com.unity.ml-agents/CHANGELOG.md
  2. 4
      docs/Migrating.md
  3. 3
      docs/Training-Configuration-File.md
  4. 1
      docs/Training-ML-Agents.md
  5. 3
      docs/Using-Tensorboard.md
  6. 7
      ml-agents/mlagents/trainers/cli_utils.py
  7. 2
      ml-agents/mlagents/trainers/ghost/trainer.py
  8. 1
      ml-agents/mlagents/trainers/learn.py
  9. 1
      ml-agents/mlagents/trainers/ppo/trainer.py
  10. 1
      ml-agents/mlagents/trainers/sac/trainer.py
  11. 2
      ml-agents/mlagents/trainers/settings.py
  12. 8
      ml-agents/mlagents/trainers/tests/test_learn.py
  13. 1
      ml-agents/mlagents/trainers/tests/test_ppo.py
  14. 49
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py
  15. 1
      ml-agents/mlagents/trainers/tests/test_sac.py
  16. 2
      ml-agents/mlagents/trainers/tests/test_simple_rl.py
  17. 2
      ml-agents/mlagents/trainers/tests/test_trainer_controller.py
  18. 38
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  19. 3
      ml-agents/mlagents/trainers/trainer/trainer.py
  20. 20
      ml-agents/mlagents/trainers/trainer_controller.py

1
com.unity.ml-agents/CHANGELOG.md


#### ml-agents / ml-agents-envs / gym-unity (Python)
- Unity Player logs are now written out to the results directory. (#3877)
- Run configuration YAML files are written out to the results directory at the end of the run. (#3815)
- The `--save-freq` CLI option has been removed, and replaced by a `checkpoint_interval` option in the trainer configuration YAML. (#4034)
- When trying to load/resume from a checkpoint created with an earlier verison of ML-Agents,
a warning will be thrown. (#4035)
### Bug Fixes

4
docs/Migrating.md


- `use_visual` and `allow_multiple_visual_obs` in the `UnityToGymWrapper` constructor
were replaced by `allow_multiple_obs` which allows one or more visual observations and
vector observations to be used simultaneously.
- `--save-freq` has been removed from the CLI and is now configurable in the trainer configuration
file.
- `--lesson` has been removed from the CLI. Lessons will resume when using `--resume`.
To start at a different lesson, modify your Curriculum configuration.

- If you use the `UnityToGymWrapper`, remove `use_visual` and `allow_multiple_visual_obs`
from the constructor and add `allow_multiple_obs = True` if the environment contains either
both visual and vector observations or multiple visual observations.
- If you were setting `--save-freq` in the CLI, add a `checkpoint_interval` value in your
trainer configuration, and set it equal to `save-freq * n_agents_in_scene`.
## Migrating from 0.15 to Release 1

3
docs/Training-Configuration-File.md


| `summary_freq` | (default = `50000`) Number of experiences that needs to be collected before generating and displaying training statistics. This determines the granularity of the graphs in Tensorboard. |
| `time_horizon` | (default = `64`) How many steps of experience to collect per-agent before adding it to the experience buffer. When this limit is reached before the end of an episode, a value estimate is used to predict the overall expected reward from the agent's current state. As such, this parameter trades off between a less biased, but higher variance estimate (long time horizon) and more biased, but less varied estimate (short time horizon). In cases where there are frequent rewards within an episode, or episodes are prohibitively large, a smaller number can be more ideal. This number should be large enough to capture all the important behavior within a sequence of an agent's actions. <br><br> Typical range: `32` - `2048` |
| `max_steps` | (default = `500000`) Total number of steps (i.e., observation collected and action taken) that must be taken in the environment (or across all environments if using multiple in parallel) before ending the training process. If you have multiple agents with the same behavior name within your environment, all steps taken by those agents will contribute to the same `max_steps` count. <br><br>Typical range: `5e5` - `1e7` |
| `keep_checkpoints` | (default = `5`) The maximum number of model checkpoints to keep. Checkpoints are saved after the number of steps specified by the save-freq option. Once the maximum number of checkpoints has been reached, the oldest checkpoint is deleted when saving a new checkpoint. |
| `keep_checkpoints` | (default = `5`) The maximum number of model checkpoints to keep. Checkpoints are saved after the number of steps specified by the checkpoint_interval option. Once the maximum number of checkpoints has been reached, the oldest checkpoint is deleted when saving a new checkpoint. |
| `checkpoint_interval` | (default = `500000`) The number of experiences collected between each checkpoint by the trainer. A maximum of `keep_checkpoints` checkpoints are saved before old ones are deleted. |
| `init_path` | (default = None) Initialize trainer from a previously saved model. Note that the prior run should have used the same trainer configurations as the current run, and have been saved with the same version of ML-Agents. <br><br>You should provide the full path to the folder where the checkpoints were saved, e.g. `./models/{run-id}/{behavior_name}`. This option is provided in case you want to initialize different behaviors from different runs; in most cases, it is sufficient to use the `--initialize-from` CLI parameter to initialize all models from the same run. |
| `threaded` | (default = `true`) By default, model updates can happen while the environment is being stepped. This violates the [on-policy](https://spinningup.openai.com/en/latest/user/algorithms.html#the-on-policy-algorithms) assumption of PPO slightly in exchange for a training speedup. To maintain the strict on-policyness of PPO, you can disable parallel updates by setting `threaded` to `false`. There is usually no reason to turn `threaded` off for SAC. |
| `hyperparameters -> learning_rate` | (default = `3e-4`) Initial learning rate for gradient descent. Corresponds to the strength of each gradient descent update step. This should typically be decreased if training is unstable, and the reward does not consistently increase. <br><br>Typical range: `1e-5` - `1e-3` |

1
docs/Training-ML-Agents.md


time_horizon: 64
summary_freq: 10000
keep_checkpoints: 5
checkpoint_interval: 50000
threaded: true
init_path: null

3
docs/Using-Tensorboard.md


The TensorBoard window also provides options for how to display and smooth
graphs.
When you run the training program, `mlagents-learn`, you can use the
`--save-freq` option to specify how frequently to save the statistics.
## The ML-Agents Toolkit training statistics
The ML-Agents training program saves the following statistics:

7
ml-agents/mlagents/trainers/cli_utils.py


action=DetectDefault,
)
argparser.add_argument(
"--save-freq",
default=50000,
type=int,
help="How often (in steps) to save the model during training",
action=DetectDefault,
)
argparser.add_argument(
"--seed",
default=-1,
type=int,

2
ml-agents/mlagents/trainers/ghost/trainer.py


except AgentManagerQueue.Empty:
pass
self.next_summary_step = self.trainer.next_summary_step
self._next_summary_step = self.trainer._next_summary_step
self.trainer.advance()
if self.get_step - self.last_team_change > self.steps_to_train_team:
self.controller.change_training_team(self.get_step)

1
ml-agents/mlagents/trainers/learn.py


trainer_factory,
write_path,
checkpoint_settings.run_id,
checkpoint_settings.save_freq,
maybe_meta_curriculum,
not checkpoint_settings.inference,
run_seed,

1
ml-agents/mlagents/trainers/ppo/trainer.py


self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
# Needed to resume loads properly
self.step = policy.get_current_step()
self.next_summary_step = self._get_next_summary_step()
def get_policy(self, name_behavior_id: str) -> TFPolicy:
"""

1
ml-agents/mlagents/trainers/sac/trainer.py


self.reward_signal_update_steps = int(
max(1, self.step / self.reward_signal_steps_per_update)
)
self.next_summary_step = self._get_next_summary_step()
def get_policy(self, name_behavior_id: str) -> TFPolicy:
"""

2
ml-agents/mlagents/trainers/settings.py


init_path: Optional[str] = None
output_path: str = "default"
keep_checkpoints: int = 5
checkpoint_interval: int = 500000
max_steps: int = 500000
time_horizon: int = 64
summary_freq: int = 50000

@attr.s(auto_attribs=True)
class CheckpointSettings:
save_freq: int = parser.get_default("save_freq")
run_id: str = parser.get_default("run_id")
initialize_from: str = parser.get_default("initialize_from")
load_model: bool = parser.get_default("load_model")

8
ml-agents/mlagents/trainers/tests/test_learn.py


seed: 9870
checkpoint_settings:
run_id: uselessrun
save_freq: 654321
debug: false
"""

trainer_factory_mock.return_value,
"results/ppo",
"ppo",
50000,
None,
True,
0,

assert opt.checkpoint_settings.resume is False
assert opt.checkpoint_settings.inference is False
assert opt.checkpoint_settings.run_id == "ppo"
assert opt.checkpoint_settings.save_freq == 50000
assert opt.env_settings.seed == -1
assert opt.env_settings.base_port == 5005
assert opt.env_settings.num_envs == 1

"--resume",
"--inference",
"--run-id=myawesomerun",
"--save-freq=123456",
"--seed=7890",
"--train",
"--base-port=4004",

assert opt.env_settings.env_path == "./myenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.run_id == "myawesomerun"
assert opt.checkpoint_settings.save_freq == 123456
assert opt.env_settings.seed == 7890
assert opt.env_settings.base_port == 4004
assert opt.env_settings.num_envs == 2

assert opt.env_settings.env_path == "./oldenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.run_id == "uselessrun"
assert opt.checkpoint_settings.save_freq == 654321
assert opt.env_settings.seed == 9870
assert opt.env_settings.base_port == 4001
assert opt.env_settings.num_envs == 4

"--resume",
"--inference",
"--run-id=myawesomerun",
"--save-freq=123456",
"--seed=7890",
"--train",
"--base-port=4004",

assert opt.env_settings.env_path == "./myenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.run_id == "myawesomerun"
assert opt.checkpoint_settings.save_freq == 123456
assert opt.env_settings.seed == 7890
assert opt.env_settings.base_port == 4004
assert opt.env_settings.num_envs == 2

1
ml-agents/mlagents/trainers/tests/test_ppo.py


# Make sure the summary steps were loaded properly
assert trainer.get_step == 2000
assert trainer.next_summary_step > 2000
# Test incorrect class of policy
policy = mock.Mock()

49
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


def create_rl_trainer():
mock_brainparams = create_mock_brain()
trainer = FakeTrainer(mock_brainparams, TrainerSettings(max_steps=100), True, 0)
trainer = FakeTrainer(
mock_brainparams,
TrainerSettings(max_steps=100, checkpoint_interval=10, summary_freq=20),
True,
0,
)
trainer.set_is_policy_updating(True)
return trainer

# Check that the buffer has been cleared
assert not trainer.should_still_train
assert mocked_clear_update_buffer.call_count > 0
@mock.patch("mlagents.trainers.trainer.trainer.Trainer.save_model")
@mock.patch("mlagents.trainers.trainer.trainer.StatsReporter.write_stats")
def test_summary_checkpoint(mock_write_summary, mock_save_model):
trainer = create_rl_trainer()
trajectory_queue = AgentManagerQueue("testbrain")
policy_queue = AgentManagerQueue("testbrain")
trainer.subscribe_trajectory_queue(trajectory_queue)
trainer.publish_policy_queue(policy_queue)
time_horizon = 10
summary_freq = trainer.trainer_settings.summary_freq
checkpoint_interval = trainer.trainer_settings.checkpoint_interval
trajectory = mb.make_fake_trajectory(
length=time_horizon,
max_step_complete=True,
vec_obs_size=1,
num_vis_obs=0,
action_space=[2],
)
# Check that we can turn off the trainer and that the buffer is cleared
num_trajectories = 5
for _ in range(0, num_trajectories):
trajectory_queue.put(trajectory)
trainer.advance()
# Check that there is stuff in the policy queue
policy_queue.get_nowait()
# Check that we have called write_summary the appropriate number of times
calls = [
mock.call(step)
for step in range(summary_freq, num_trajectories * time_horizon, summary_freq)
]
mock_write_summary.assert_has_calls(calls, any_order=True)
calls = [
mock.call(trainer.brain_name)
for step in range(
checkpoint_interval, num_trajectories * time_horizon, checkpoint_interval
)
]
mock_save_model.assert_has_calls(calls, any_order=True)

1
ml-agents/mlagents/trainers/tests/test_sac.py


# Make sure the summary steps were loaded properly
assert trainer.get_step == 2000
assert trainer.next_summary_step > 2000
# Test incorrect class of policy
policy = mock.Mock()

2
ml-agents/mlagents/trainers/tests/test_simple_rl.py


# Create controller and begin training.
with tempfile.TemporaryDirectory() as dir:
run_id = "id"
save_freq = 99999
seed = 1337
StatsReporter.writers.clear() # Clear StatsReporters so we don't write to file
debug_writer = DebugWriter()

training_seed=seed,
sampler_manager=SamplerManager(None),
resampling_interval=None,
save_freq=save_freq,
)
# Begin training

2
ml-agents/mlagents/trainers/tests/test_trainer_controller.py


trainer_factory=trainer_factory_mock,
output_path="test_model_path",
run_id="test_run_id",
save_freq=100,
meta_curriculum=None,
train=True,
training_seed=99,

trainer_factory=trainer_factory_mock,
output_path="",
run_id="1",
save_freq=1,
meta_curriculum=None,
train=True,
training_seed=seed,

38
ml-agents/mlagents/trainers/trainer/rl_trainer.py


import abc
import time
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trainer import Trainer

from mlagents.trainers.stats import StatsPropertyType
RewardSignalResults = Dict[str, RewardSignalResult]
logger = get_logger(__name__)
class RLTrainer(Trainer): # pylint: disable=abstract-method

self._stats_reporter.add_property(
StatsPropertyType.HYPERPARAMETERS, self.trainer_settings.as_dict()
)
self._next_save_step = 0
self._next_summary_step = 0
def end_episode(self) -> None:
"""

:param n_steps: number of steps to increment the step count by
"""
self.step += n_steps
self.next_summary_step = self._get_next_summary_step()
self._next_summary_step = self._get_next_interval_step(self.summary_freq)
self._next_save_step = self._get_next_interval_step(
self.trainer_settings.checkpoint_interval
)
def _get_next_summary_step(self) -> int:
def _get_next_interval_step(self, interval: int) -> int:
Get the next step count that should result in a summary write.
Get the next step count that should result in an action.
:param interval: The interval between actions.
return self.step + (self.summary_freq - self.step % self.summary_freq)
return self.step + (interval - self.step % interval)
def _write_summary(self, step: int) -> None:
"""

:param trajectory: The Trajectory tuple containing the steps to be processed.
"""
self._maybe_write_summary(self.get_step + len(trajectory.steps))
self._maybe_save_model(self.get_step + len(trajectory.steps))
self._increment_step(len(trajectory.steps), trajectory.behavior_id)
def _maybe_write_summary(self, step_after_process: int) -> None:

:param step_after_process: the step count after processing the next trajectory.
"""
if step_after_process >= self.next_summary_step and self.get_step != 0:
self._write_summary(self.next_summary_step)
if self._next_summary_step == 0: # Don't write out the first one
self._next_summary_step = self._get_next_interval_step(self.summary_freq)
if step_after_process >= self._next_summary_step and self.get_step != 0:
self._write_summary(self._next_summary_step)
def _maybe_save_model(self, step_after_process: int) -> None:
"""
If processing the trajectory will make the step exceed the next model write,
save the model. This logic ensures models are written on the update step and not in between.
:param step_after_process: the step count after processing the next trajectory.
"""
if self._next_save_step == 0: # Don't save the first one
self._next_save_step = self._get_next_interval_step(
self.trainer_settings.checkpoint_interval
)
if step_after_process >= self._next_save_step and self.get_step != 0:
logger.info(f"Checkpointing model for {self.brain_name}.")
self.save_model(self.brain_name)
def advance(self) -> None:
"""

3
ml-agents/mlagents/trainers/trainer/trainer.py


from collections import deque
from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import timed
from mlagents.model_serialization import export_policy_model, SerializationSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.stats import StatsReporter

self.trajectory_queues: List[AgentManagerQueue[Trajectory]] = []
self.step: int = 0
self.summary_freq = self.trainer_settings.summary_freq
self.next_summary_step = self.summary_freq
@property
def stats_reporter(self):

"""
return self._reward_buffer
@timed
def save_model(self, name_behavior_id: str) -> None:
"""
Saves the model

20
ml-agents/mlagents/trainers/trainer_controller.py


trainer_factory: TrainerFactory,
output_path: str,
run_id: str,
save_freq: int,
meta_curriculum: Optional[MetaCurriculum],
train: bool,
training_seed: int,

:param output_path: Path to save the model.
:param summaries_dir: Folder to save training summaries.
:param run_id: The sub-directory name for model and summary statistics
:param save_freq: Frequency at which to save model
:param meta_curriculum: MetaCurriculum object which stores information about all curricula.
:param train: Whether to train model, or only run inference.
:param training_seed: Seed to use for Numpy and Tensorflow random number generation.

self.output_path = output_path
self.logger = get_logger(__name__)
self.run_id = run_id
self.save_freq = save_freq
self.train_model = train
self.meta_curriculum = meta_curriculum
self.sampler_manager = sampler_manager

sampled_reset_param.update(new_meta_curriculum_config)
env.reset(config=sampled_reset_param)
def _should_save_model(self, global_step: int) -> bool:
return (
global_step % self.save_freq == 0 and global_step != 0 and self.train_model
)
def _not_done_training(self) -> bool:
return (
any(t.should_still_train for t in self.trainers.values())

for _ in range(n_steps):
global_step += 1
self.reset_env_if_ready(env_manager, global_step)
if self._should_save_model(global_step):
self._save_model()
# Final save Tensorflow model
if global_step != 0 and self.train_model:
self._save_model()
except (
KeyboardInterrupt,
UnityCommunicationException,

self.join_threads()
if self.train_model:
self._save_model_when_interrupted()
self.logger.info(
"Learning was interrupted. Please wait while the graph is generated."
)
if isinstance(ex, KeyboardInterrupt) or isinstance(
ex, UnityCommunicatorStoppedException
):

raise ex
finally:
if self.train_model:
self._save_model()
self._export_graph()
def end_trainer_episodes(

正在加载...
取消
保存