浏览代码

[refactor] Store and restore state along with checkpoints (#4025)

/MLA-1734-demo-provider
GitHub 5 年前
当前提交
f5435876
共有 13 个文件被更改,包括 242 次插入46 次删除
  1. 12
      com.unity.ml-agents/CHANGELOG.md
  2. 2
      docs/Migrating.md
  3. 8
      docs/Training-ML-Agents.md
  4. 2
      ml-agents/README.md
  5. 7
      ml-agents/mlagents/trainers/cli_utils.py
  6. 22
      ml-agents/mlagents/trainers/learn.py
  7. 25
      ml-agents/mlagents/trainers/meta_curriculum.py
  8. 1
      ml-agents/mlagents/trainers/settings.py
  9. 7
      ml-agents/mlagents/trainers/tests/test_learn.py
  10. 23
      ml-agents/mlagents/trainers/tests/test_meta_curriculum.py
  11. 4
      ml-agents/mlagents/trainers/trainer_controller.py
  12. 60
      ml-agents/mlagents/trainers/tests/test_training_status.py
  13. 115
      ml-agents/mlagents/trainers/training_status.py

12
com.unity.ml-agents/CHANGELOG.md


- `use_visual` and `allow_multiple_visual_obs` in the `UnityToGymWrapper` constructor
were replaced by `allow_multiple_obs` which allows one or more visual observations and
vector observations to be used simultaneously. (#3981) Thank you @shakenes !
### Minor Changes
#### com.unity.ml-agents (C#)
- `ObservableAttribute` was added. Adding the attribute to fields or properties on an Agent will allow it to generate
observations via reflection. (#3925, #4006)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Curriculum and Parameter Randomization configurations have been merged
into the main training configuration file. Note that this means training
configuration files are now environment-specific. (#3791)

directory. (#3829)
- When using Curriculum, the current lesson will resume if training is quit and resumed. As such,
the `--lesson` CLI option has been removed. (#4025)
### Minor Changes
#### com.unity.ml-agents (C#)
- `ObservableAttribute` was added. Adding the attribute to fields or properties on an Agent will allow it to generate
observations via reflection. (#3925, #4006)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Unity Player logs are now written out to the results directory. (#3877)
- Run configuration YAML files are written out to the results directory at the end of the run. (#3815)
- When trying to load/resume from a checkpoint created with an earlier verison of ML-Agents,

2
docs/Migrating.md


- `use_visual` and `allow_multiple_visual_obs` in the `UnityToGymWrapper` constructor
were replaced by `allow_multiple_obs` which allows one or more visual observations and
vector observations to be used simultaneously.
- `--lesson` has been removed from the CLI. Lessons will resume when using `--resume`.
To start at a different lesson, modify your Curriculum configuration.
### Steps to Migrate
- To upgrade your configuration files, an upgrade script has been provided. Run `python config/update_config.py

8
docs/Training-ML-Agents.md


mlagents-learn config/ppo/WallJump_curriculum.yaml --run-id=wall-jump-curriculum
```
We can then keep track of the current lessons and progresses via TensorBoard.
**Note**: If you are resuming a training session that uses curriculum, please
pass the number of the last-reached lesson using the `--lesson` flag when
running `mlagents-learn`.
We can then keep track of the current lessons and progresses via TensorBoard. If you've terminated
the run, you can resume it using `--resume` and lesson progress will start off where it
ended.
### Environment Parameter Randomization

2
ml-agents/README.md


cooperative behavior among different agents is not stable.
- Resuming self-play from a checkpoint resets the reported ELO to the default
value.
- Resuming curriculum learning from a checkpoint requires the last lesson be
specified using the `--lesson` CLI option

7
ml-agents/mlagents/trainers/cli_utils.py


action=DetectDefault,
)
argparser.add_argument(
"--lesson",
default=0,
type=int,
help="The lesson to start with when performing curriculum training",
action=DetectDefault,
)
argparser.add_argument(
"--load",
default=False,
dest="load_model",

22
ml-agents/mlagents/trainers/learn.py


from mlagents.trainers.sampler_class import SamplerManager
from mlagents.trainers.exception import SamplerException
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.training_status import GlobalTrainingStatus
from mlagents_envs.base_env import BaseEnv
from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager
from mlagents_envs.side_channel.side_channel import SideChannel

from mlagents_envs import logging_util
logger = logging_util.get_logger(__name__)
TRAINING_STATUS_FILE_NAME = "training_status.json"
def get_version_string() -> str:

)
# Make run logs directory
os.makedirs(run_logs_dir, exist_ok=True)
# Load any needed states
if checkpoint_settings.resume:
GlobalTrainingStatus.load_state(
os.path.join(run_logs_dir, "training_status.json")
)
# Configure CSV, Tensorboard Writers and StatsReporter
# We assume reward and episode length are needed in the CSV.
csv_writer = CSVWriter(

env_factory, engine_config, env_settings.num_envs
)
maybe_meta_curriculum = try_create_meta_curriculum(
options.curriculum, env_manager, checkpoint_settings.lesson
options.curriculum, env_manager, restore=checkpoint_settings.resume
)
sampler_manager, resampling_interval = create_sampler_manager(
options.parameter_randomization, run_seed

env_manager.close()
write_run_options(write_path, options)
write_timing_tree(run_logs_dir)
write_training_status(run_logs_dir)
def write_run_options(output_dir: str, run_options: RunOptions) -> None:

)
def write_training_status(output_dir: str) -> None:
GlobalTrainingStatus.save_state(os.path.join(output_dir, TRAINING_STATUS_FILE_NAME))
def write_timing_tree(output_dir: str) -> None:
timing_path = os.path.join(output_dir, "timers.json")
try:

def try_create_meta_curriculum(
curriculum_config: Optional[Dict], env: SubprocessEnvManager, lesson: int
curriculum_config: Optional[Dict], env: SubprocessEnvManager, restore: bool = False
# TODO: Should be able to start learning at different lesson numbers
# for each curriculum.
meta_curriculum.set_all_curricula_to_lesson_num(lesson)
if restore:
meta_curriculum.try_restore_all_curriculum()
return meta_curriculum

25
ml-agents/mlagents/trainers/meta_curriculum.py


from typing import Dict, Set
from mlagents.trainers.curriculum import Curriculum
from mlagents.trainers.settings import CurriculumSettings
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
from mlagents_envs.logging_util import get_logger

)
return ret
def set_all_curricula_to_lesson_num(self, lesson_num):
"""Sets all the curricula in this meta curriculum to a specified
lesson number.
Args:
lesson_num (int): The lesson number which all the curricula will
be set to.
def try_restore_all_curriculum(self):
for _, curriculum in self.brains_to_curricula.items():
curriculum.lesson_num = lesson_num
Tries to restore all the curriculums to what is saved in training_status.json
"""
for brain_name, curriculum in self.brains_to_curricula.items():
lesson_num = GlobalTrainingStatus.get_parameter_state(
brain_name, StatusType.LESSON_NUM
)
if lesson_num is not None:
logger.info(
f"Resuming curriculum for {brain_name} at lesson {lesson_num}."
)
curriculum.lesson_num = lesson_num
else:
curriculum.lesson_num = 0
def get_config(self):
"""Get the combined configuration of all curricula in this

1
ml-agents/mlagents/trainers/settings.py


force: bool = parser.get_default("force")
train_model: bool = parser.get_default("train_model")
inference: bool = parser.get_default("inference")
lesson: int = parser.get_default("lesson")
@attr.s(auto_attribs=True)

7
ml-agents/mlagents/trainers/tests/test_learn.py


base_port: 4001
seed: 9870
checkpoint_settings:
lesson: 2
run_id: uselessrun
save_freq: 654321
debug: false

assert opt.behaviors == {}
assert opt.env_settings.env_path is None
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.lesson == 0
assert opt.checkpoint_settings.resume is False
assert opt.checkpoint_settings.inference is False
assert opt.checkpoint_settings.run_id == "ppo"

full_args = [
"mytrainerpath",
"--env=./myenvfile",
"--lesson=3",
"--resume",
"--inference",
"--run-id=myawesomerun",

assert opt.behaviors == {}
assert opt.env_settings.env_path == "./myenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.lesson == 3
assert opt.checkpoint_settings.run_id == "myawesomerun"
assert opt.checkpoint_settings.save_freq == 123456
assert opt.env_settings.seed == 7890

assert opt.behaviors == {}
assert opt.env_settings.env_path == "./oldenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.lesson == 2
assert opt.checkpoint_settings.run_id == "uselessrun"
assert opt.checkpoint_settings.save_freq == 654321
assert opt.env_settings.seed == 9870

full_args = [
"mytrainerpath",
"--env=./myenvfile",
"--lesson=3",
"--resume",
"--inference",
"--run-id=myawesomerun",

assert opt.behaviors == {}
assert opt.env_settings.env_path == "./myenvfile"
assert opt.parameter_randomization is None
assert opt.checkpoint_settings.lesson == 3
assert opt.checkpoint_settings.run_id == "myawesomerun"
assert opt.checkpoint_settings.save_freq == 123456
assert opt.env_settings.seed == 7890

23
ml-agents/mlagents/trainers/tests/test_meta_curriculum.py


import pytest
from unittest.mock import patch, Mock
from unittest.mock import patch, Mock, call
from mlagents.trainers.meta_curriculum import MetaCurriculum

)
from mlagents.trainers.tests.test_curriculum import dummy_curriculum_config
from mlagents.trainers.settings import CurriculumSettings
from mlagents.trainers.training_status import StatusType
@pytest.fixture

curriculum_b.increment_lesson.assert_not_called()
def test_set_all_curriculums_to_lesson_num():
@patch("mlagents.trainers.meta_curriculum.GlobalTrainingStatus")
def test_restore_curriculums(mock_trainingstatus):
meta_curriculum.set_all_curricula_to_lesson_num(2)
# Test restore to value
mock_trainingstatus.get_parameter_state.return_value = 2
meta_curriculum.try_restore_all_curriculum()
mock_trainingstatus.get_parameter_state.assert_has_calls(
[call("Brain1", StatusType.LESSON_NUM), call("Brain2", StatusType.LESSON_NUM)],
any_order=True,
)
# Test restore to None
mock_trainingstatus.get_parameter_state.return_value = None
meta_curriculum.try_restore_all_curriculum()
assert meta_curriculum.brains_to_curricula["Brain1"].lesson_num == 0
assert meta_curriculum.brains_to_curricula["Brain2"].lesson_num == 0
def test_get_config():

4
ml-agents/mlagents/trainers/trainer_controller.py


from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.agent_processor import AgentManager
from mlagents.trainers.settings import CurriculumSettings
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
class TrainerController(object):

if brain_name in self.trainers:
self.trainers[brain_name].stats_reporter.set_stat(
"Environment/Lesson", curr.lesson_num
)
GlobalTrainingStatus.set_parameter_state(
brain_name, StatusType.LESSON_NUM, curr.lesson_num
)
for trainer in self.trainers.values():

60
ml-agents/mlagents/trainers/tests/test_training_status.py


import os
import unittest
import json
from enum import Enum
from mlagents.trainers.training_status import (
StatusType,
StatusMetaData,
GlobalTrainingStatus,
)
def test_globaltrainingstatus(tmpdir):
path_dir = os.path.join(tmpdir, "test.json")
GlobalTrainingStatus.set_parameter_state("Category1", StatusType.LESSON_NUM, 3)
GlobalTrainingStatus.save_state(path_dir)
with open(path_dir, "r") as fp:
test_json = json.load(fp)
assert "Category1" in test_json
assert StatusType.LESSON_NUM.value in test_json["Category1"]
assert test_json["Category1"][StatusType.LESSON_NUM.value] == 3
assert "metadata" in test_json
GlobalTrainingStatus.load_state(path_dir)
restored_val = GlobalTrainingStatus.get_parameter_state(
"Category1", StatusType.LESSON_NUM
)
assert restored_val == 3
# Test unknown categories and status types (keys)
unknown_category = GlobalTrainingStatus.get_parameter_state(
"Category3", StatusType.LESSON_NUM
)
class FakeStatusType(Enum):
NOTAREALKEY = "notarealkey"
unknown_key = GlobalTrainingStatus.get_parameter_state(
"Category1", FakeStatusType.NOTAREALKEY
)
assert unknown_category is None
assert unknown_key is None
class StatsMetaDataTest(unittest.TestCase):
def test_metadata_compare(self):
# Test write_stats
with self.assertLogs("mlagents.trainers", level="WARNING") as cm:
default_metadata = StatusMetaData()
version_statsmetadata = StatusMetaData(mlagents_version="test")
default_metadata.check_compatibility(version_statsmetadata)
tf_version_statsmetadata = StatusMetaData(tensorflow_version="test")
default_metadata.check_compatibility(tf_version_statsmetadata)
# Assert that 2 warnings have been thrown
assert len(cm.output) == 2

115
ml-agents/mlagents/trainers/training_status.py


from typing import Dict, Any
from enum import Enum
from collections import defaultdict
import json
import attr
import cattr
from mlagents.tf_utils import tf
from mlagents_envs.logging_util import get_logger
from mlagents.trainers import __version__
from mlagents.trainers.exception import TrainerError
logger = get_logger(__name__)
STATUS_FORMAT_VERSION = "0.1.0"
class StatusType(Enum):
LESSON_NUM = "lesson_num"
STATS_METADATA = "metadata"
@attr.s(auto_attribs=True)
class StatusMetaData:
stats_format_version: str = STATUS_FORMAT_VERSION
mlagents_version: str = __version__
tensorflow_version: str = tf.__version__
def to_dict(self) -> Dict[str, str]:
return cattr.unstructure(self)
@staticmethod
def from_dict(import_dict: Dict[str, str]) -> "StatusMetaData":
return cattr.structure(import_dict, StatusMetaData)
def check_compatibility(self, other: "StatusMetaData") -> None:
"""
Check compatibility with a loaded StatsMetaData and warn the user
if versions mismatch. This is used for resuming from old checkpoints.
"""
# This should cover all stats version mismatches as well.
if self.mlagents_version != other.mlagents_version:
logger.warning(
"Checkpoint was loaded from a different version of ML-Agents. Some things may not resume properly."
)
if self.tensorflow_version != other.tensorflow_version:
logger.warning(
"Tensorflow checkpoint was saved with a different version of Tensorflow. Model may not resume properly."
)
class GlobalTrainingStatus:
"""
GlobalTrainingStatus class that contains static methods to save global training status and
load it on a resume. These are values that might be needed for the training resume that
cannot/should not be captured in a model checkpoint, such as curriclum lesson.
"""
saved_state: Dict[str, Dict[str, Any]] = defaultdict(lambda: {})
@staticmethod
def load_state(path: str) -> None:
"""
Load a JSON file that contains saved state.
:param path: Path to the JSON file containing the state.
"""
try:
with open(path, "r") as f:
loaded_dict = json.load(f)
# Compare the metadata
_metadata = loaded_dict[StatusType.STATS_METADATA.value]
StatusMetaData.from_dict(_metadata).check_compatibility(StatusMetaData())
# Update saved state.
GlobalTrainingStatus.saved_state.update(loaded_dict)
except FileNotFoundError:
logger.warning(
"Training status file not found. Not all functions will resume properly."
)
except KeyError:
raise TrainerError(
"Metadata not found, resuming from an incompatible version of ML-Agents."
)
@staticmethod
def save_state(path: str) -> None:
"""
Save a JSON file that contains saved state.
:param path: Path to the JSON file containing the state.
"""
GlobalTrainingStatus.saved_state[
StatusType.STATS_METADATA.value
] = StatusMetaData().to_dict()
with open(path, "w") as f:
json.dump(GlobalTrainingStatus.saved_state, f, indent=4)
@staticmethod
def set_parameter_state(category: str, key: StatusType, value: Any) -> None:
"""
Stores an arbitrary-named parameter in the global saved state.
:param category: The category (usually behavior name) of the parameter.
:param key: The parameter, e.g. lesson number.
:param value: The value.
"""
GlobalTrainingStatus.saved_state[category][key.value] = value
@staticmethod
def get_parameter_state(category: str, key: StatusType) -> Any:
"""
Loads an arbitrary-named parameter from training_status.json.
If not found, returns None.
:param category: The category (usually behavior name) of the parameter.
:param key: The statistic, e.g. lesson number.
:param value: The value.
"""
return GlobalTrainingStatus.saved_state[category].get(key.value, None)
正在加载...
取消
保存