
Add 'run-experiment' script, simpler curriculum config (#3186)

This change adds a new 'mlagents-run-experiment' endpoint which
accepts a single YAML/JSON file providing all of the information that
mlagents-learn accepts via command-line arguments and file inputs.
As part of this change the curriculum configuration is simplified to
accept only a single file for all the curricula in an environment
rather than a file for each behavior.
Committed by GitHub 5 years ago
Current commit: b0a2a54f
12 files changed, with 287 insertions and 266 deletions
1. docs/Migrating.md (4 changes)
2. docs/Training-Curriculum-Learning.md (87 changes)
3. ml-agents/mlagents/trainers/learn.py (214 changes)
4. ml-agents/mlagents/trainers/meta_curriculum.py (42 changes)
5. ml-agents/mlagents/trainers/tests/test_learn.py (20 changes)
6. ml-agents/mlagents/trainers/tests/test_meta_curriculum.py (128 changes)
7. ml-agents/setup.py (7 changes)
8. config/curricula/test.yaml (9 changes)
9. config/curricula/wall_jump.yaml (16 changes)
10. ml-agents/mlagents/trainers/run_experiment.py (26 changes)

docs/Migrating.md (4 changes)


### Important changes

* Trainer steps are now counted per-Agent, not per-environment as in previous versions. For instance, if you have 10 Agents in the scene, 20 environment steps now correspond to 200 steps as printed in the terminal and in TensorBoard.
* Curriculum config files are now YAML formatted, and all curricula for a training run are combined into a single file.
* The `--num-runs` command-line option has been removed.
* Combine curriculum configs into a single file. See [the WallJump curricula](../config/curricula/wall_jump.yaml) for an example of the new curriculum config format. A tool like https://www.json2yaml.com may be useful to help with the conversion (a scripted alternative is sketched below).
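If you have many per-behavior JSON curricula to migrate, a short script along the following lines could merge them into one YAML file. This is an illustrative sketch, not part of the toolkit; it assumes the old convention of one `<BehaviorName>.json` file per behavior in a curriculum folder.

```python
# Illustrative migration sketch (not part of ml-agents): merge a folder of
# per-behavior JSON curricula into a single YAML file in the new format.
import json
import os

import yaml  # PyYAML


def merge_curricula(folder: str, out_path: str) -> None:
    combined = {}
    for filename in os.listdir(folder):
        behavior_name, ext = os.path.splitext(filename)
        if ext.lower() != ".json":
            continue
        with open(os.path.join(folder, filename)) as f:
            combined[behavior_name] = json.load(f)
    with open(out_path, "w") as f:
        yaml.safe_dump(combined, f, default_flow_style=False)


# Example: merge_curricula("config/curricula/wall-jump/", "config/curricula/wall_jump.yaml")
```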
## Migrating from ML-Agents toolkit v0.12.0 to v0.13.0

docs/Training-Curriculum-Learning.md (87 changes)


```diff
 likely never, or very rarely scale the wall properly to achieve the reward.
 If we start with a simpler task, such as moving toward an unobstructed goal,
 then the agent can easily learn to accomplish the task. From there, we can
-slowly add to the difficulty of the task by increasing the size of the wall,
+slowly add to the difficulty of the task by increasing the size of the wall
-wall. We are including just such an environment with the ML-Agents toolkit 0.2,
+wall. We have included an environment to demonstrate this with ML-Agents,
 called __Wall Jump__.
```
![Wall](images/curriculum.png)

```diff
-To see this in action, observe the two learning curves below. Each displays the
-reward over time for an agent trained using PPO with the same set of training
-hyperparameters. The difference is that one agent was trained using the
+To see curriculum learning in action, observe the two learning curves below. Each
+displays the reward over time for an agent trained using PPO with the same set of
+training hyperparameters. The difference is that one agent was trained using the
 full-height wall version of the task, and the other agent was trained using the
 curriculum version of the task. As you can see, without using curriculum
 learning the agent has a lot of difficulty. We think that by using well-crafted
```

```diff
 ## How-To

 Each group of Agents under the same `Behavior Name` in an environment can have
-a corresponding curriculum. These
-curriculums are held in what we call a metacurriculum. A metacurriculum allows
-different groups of Agents to follow different curriculums within the same environment.
-
-### Specifying a Metacurriculum
-
-We first create a folder inside `config/curricula/` for the environment we want
-to use curriculum learning with. For example, if we were creating a
-metacurriculum for Wall Jump, we would create the folder
-`config/curricula/wall-jump/`. We will place our curriculums inside this folder.
+a corresponding curriculum. These curricula are held in what we call a "metacurriculum".
+A metacurriculum allows different groups of Agents to follow different curricula within
+the same environment.

-### Specifying a Curriculum
+### Specifying Curricula
```
```diff
-In order to define a curriculum, the first step is to decide which parameters of
-the environment will vary. In the case of the Wall Jump environment, what varies
-is the height of the wall. We define this as a `Shared Float Property` that
-can be accessed in `Academy.FloatProperties`, and by doing so it becomes
+In order to define the curricula, the first step is to decide which parameters of
+the environment will vary. In the case of the Wall Jump environment,
+the height of the wall is what varies. We define this as a `Shared Float Property`
+that can be accessed in `Academy.FloatProperties`, and by doing so it becomes

-Rather than adjusting it by hand, we will create a JSON file which
-describes the structure of the curriculum. Within it, we can specify which
+Rather than adjusting it by hand, we will create a YAML file which
+describes the structure of the curricula. Within it, we can specify which

-the agent has received in the recent past is. Below is an example curriculum for
-the BigWallBehavior in the Wall Jump environment.
+the agent has received in the recent past is. Below is an example config for the
+curricula for the Wall Jump environment.
```
The old per-behavior JSON format is removed:

```json
{
    "measure" : "progress",
    "thresholds" : [0.1, 0.3, 0.5],
    "min_lesson_length" : 100,
    "signal_smoothing" : true,
    "parameters" :
    {
        "big_wall_min_height" : [0.0, 4.0, 6.0, 8.0],
        "big_wall_max_height" : [4.0, 7.0, 8.0, 8.0]
    }
}
```

and replaced with a single YAML file covering all curricula:

```yaml
BigWallJump:
  measure: progress
  thresholds: [0.1, 0.3, 0.5]
  min_lesson_length: 100
  signal_smoothing: true
  parameters:
    big_wall_min_height: [0.0, 4.0, 6.0, 8.0]
    big_wall_max_height: [4.0, 7.0, 8.0, 8.0]
SmallWallJump:
  measure: progress
  thresholds: [0.1, 0.3, 0.5]
  min_lesson_length: 100
  signal_smoothing: true
  parameters:
    small_wall_height: [1.5, 2.0, 2.5, 4.0]
```
At the top level of the config is the behavior name. The curriculum for each
behavior has the following parameters:

* `measure` - What to measure learning progress and advancement in lessons by.
  * `reward` - Uses a measure of received reward.
  * `progress` - Uses the ratio of steps/max_steps.

[WallJumpAgent.cs](https://github.com/Unity-Technologies/ml-agents/blob/master/UnitySDK/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs)
for an example.

```diff
-We will save this file into our metacurriculum folder with the name of its
-corresponding `Behavior Name`. For example, in the Wall Jump environment, there are
-two different `Behavior Name`s set via script in `WallJumpAgent.cs`:
-BigWallBrainLearning and SmallWallBrainLearning. If we want to define a curriculum
-for BigWallBrainLearning, we will save `BigWallBrainLearning.json` into
-`config/curricula/wall-jump/`.
```
```diff
-Once we have specified our metacurriculum and curriculums, we can launch
-`mlagents-learn` using the `--curriculum` flag to point to the metacurriculum
-folder and PPO will train using Curriculum Learning. For example, to train
-agents in the Wall Jump environment with curriculum learning, we can run
+Once we have specified our metacurriculum and curricula, we can launch
+`mlagents-learn` using the `--curriculum` flag to point to the config file
+for our curricula and PPO will train using Curriculum Learning. For example,
+to train agents in the Wall Jump environment with curriculum learning, we can run:

-mlagents-learn config/trainer_config.yaml --curriculum=config/curricula/wall-jump/ --run-id=wall-jump-curriculum --train
+mlagents-learn config/trainer_config.yaml --curriculum=config/curricula/wall_jump.yaml --run-id=wall-jump-curriculum --train

 We can then keep track of the current lessons and progress via TensorBoard.
```
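The curriculum fields above drive lesson advancement. For intuition, here is a minimal sketch of that logic; it is illustrative rather than ml-agents' actual implementation, and it ignores `min_lesson_length` and `signal_smoothing`:

```python
# Illustrative sketch of lesson advancement (not ml-agents' implementation).
# min_lesson_length and signal_smoothing are ignored here.
from typing import Dict, List


def next_lesson(lesson_num: int, measure_val: float, thresholds: List[float]) -> int:
    """Advance one lesson once the measure crosses the current lesson's threshold."""
    if lesson_num < len(thresholds) and measure_val > thresholds[lesson_num]:
        return lesson_num + 1
    return lesson_num


def reset_parameters(lesson_num: int, parameters: Dict[str, List[float]]) -> Dict[str, float]:
    """Select each reset parameter's value for the current lesson."""
    return {name: values[lesson_num] for name, values in parameters.items()}


# With measure `progress` (= steps / max_steps), crossing 0.1 moves lesson 0 -> 1:
thresholds = [0.1, 0.3, 0.5]
params = {"big_wall_min_height": [0.0, 4.0, 6.0, 8.0]}
lesson = next_lesson(0, measure_val=0.15, thresholds=thresholds)
print(lesson, reset_parameters(lesson, params))  # 1 {'big_wall_min_height': 4.0}
```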

ml-agents/mlagents/trainers/learn.py (214 changes)


```diff
 import glob
 import shutil
 import numpy as np
 import json

-from typing import Any, Callable, Optional, List, NamedTuple
+from typing import Callable, Optional, List, NamedTuple, Dict

 import mlagents.trainers
 import mlagents_envs
 from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig


-class CommandLineOptions(NamedTuple):
-    debug: bool
-    seed: int
-    env_path: str
-    run_id: str
-    load_model: bool
-    train_model: bool
-    save_freq: int
-    keep_checkpoints: int
-    base_port: int
-    num_envs: int
-    curriculum_folder: Optional[str]
-    lesson: int
-    no_graphics: bool
-    multi_gpu: bool  # ?
-    trainer_config_path: str
-    sampler_file_path: Optional[str]
-    docker_target_name: Optional[str]
-    env_args: Optional[List[str]]
-    cpu: bool
-    width: int
-    height: int
-    quality_level: int
-    time_scale: float
-    target_frame_rate: int
-
-    @staticmethod
-    def from_argparse(args: Any) -> "CommandLineOptions":
-        return CommandLineOptions(**vars(args))
-
-
-def get_version_string() -> str:
-    # pylint: disable=no-member
-    return f""" Version information:
-  ml-agents: {mlagents.trainers.__version__},
-  ml-agents-envs: {mlagents_envs.__version__},
-  Communicator API: {UnityEnvironment.API_VERSION},
-  TensorFlow: {tf_utils.tf.__version__}"""
```
```diff
-def parse_command_line(argv: Optional[List[str]] = None) -> CommandLineOptions:
-    parser = argparse.ArgumentParser(
+def _create_parser():
+    argparser = argparse.ArgumentParser(
 ...
-    parser.add_argument("trainer_config_path")
-    parser.add_argument(
+    argparser.add_argument("trainer_config_path")
+    argparser.add_argument(
 ...
-    parser.add_argument(
+    argparser.add_argument(
 ...
-        dest="curriculum_folder",
-        help="Curriculum json directory for environment",
+        dest="curriculum_config_path",
+        help="Curriculum config yaml file for environment",
 ...
-    parser.add_argument(
+    argparser.add_argument(
         "--load",
         default=False,
         dest="load_model",
 ...
-    parser.add_argument(
+    argparser.add_argument(
         "--train",
         default=False,
         dest="train_model",
 ...
 (each remaining parser.add_argument call is likewise renamed to
 argparser.add_argument; the argument bodies are unchanged and collapsed here)
 ...
-    parser.add_argument("--version", action="version", version="")
+    argparser.add_argument("--version", action="version", version="")
-    eng_conf = parser.add_argument_group(title="Engine Configuration")
+    eng_conf = argparser.add_argument_group(title="Engine Configuration")
     eng_conf.add_argument(
         "--width",
         default=84,
 ...
         type=int,
         help="The target frame rate of the Unity environment(s)",
     )
+    return argparser
```
```diff
+parser = _create_parser()
+
+
+class RunOptions(NamedTuple):
+    trainer_config: Dict
+    debug: bool = parser.get_default("debug")
+    seed: int = parser.get_default("seed")
+    env_path: Optional[str] = parser.get_default("env_path")
+    run_id: str = parser.get_default("run_id")
+    load_model: bool = parser.get_default("load_model")
+    train_model: bool = parser.get_default("train_model")
+    save_freq: int = parser.get_default("save_freq")
+    keep_checkpoints: int = parser.get_default("keep_checkpoints")
+    base_port: int = parser.get_default("base_port")
+    num_envs: int = parser.get_default("num_envs")
+    curriculum_config: Optional[Dict] = None
+    lesson: int = parser.get_default("lesson")
+    no_graphics: bool = parser.get_default("no_graphics")
+    multi_gpu: bool = parser.get_default("multi_gpu")
+    sampler_config: Optional[Dict] = None
+    docker_target_name: Optional[str] = parser.get_default("docker_target_name")
+    env_args: Optional[List[str]] = parser.get_default("env_args")
+    cpu: bool = parser.get_default("cpu")
+    width: int = parser.get_default("width")
+    height: int = parser.get_default("height")
+    quality_level: int = parser.get_default("quality_level")
+    time_scale: float = parser.get_default("time_scale")
+    target_frame_rate: int = parser.get_default("target_frame_rate")
+
+    @staticmethod
+    def from_argparse(args: argparse.Namespace) -> "RunOptions":
+        """
+        Takes an argparse.Namespace as specified in `parse_command_line`, loads input configuration files
+        from file paths, and converts to a RunOptions instance.
+        :param args: collection of command-line parameters passed to mlagents-learn
+        :return: RunOptions representing the passed in arguments, with trainer config, curriculum and sampler
+                 configs loaded from files.
+        """
+        argparse_args = vars(args)
+        docker_target_name = argparse_args["docker_target_name"]
+        trainer_config_path = argparse_args["trainer_config_path"]
+        curriculum_config_path = argparse_args["curriculum_config_path"]
+        if docker_target_name is not None:
+            trainer_config_path = f"/{docker_target_name}/{trainer_config_path}"
+            if curriculum_config_path is not None:
+                curriculum_config_path = (
+                    f"/{docker_target_name}/{curriculum_config_path}"
+                )
+        argparse_args["trainer_config"] = load_config(trainer_config_path)
+        if curriculum_config_path is not None:
+            argparse_args["curriculum_config"] = load_config(curriculum_config_path)
+        if argparse_args["sampler_file_path"] is not None:
+            argparse_args["sampler_config"] = load_config(
+                argparse_args["sampler_file_path"]
+            )
+        # Since argparse accepts file paths in the config options which don't exist in RunOptions,
+        # these keys will need to be deleted to use the **/splat operator below.
+        argparse_args.pop("sampler_file_path")
+        argparse_args.pop("curriculum_config_path")
+        argparse_args.pop("trainer_config_path")
+        return RunOptions(**vars(args))
```
```diff
+def get_version_string() -> str:
+    # pylint: disable=no-member
+    return f""" Version information:
+  ml-agents: {mlagents.trainers.__version__},
+  ml-agents-envs: {mlagents_envs.__version__},
+  Communicator API: {UnityEnvironment.API_VERSION},
+  TensorFlow: {tf_utils.tf.__version__}"""


+def parse_command_line(argv: Optional[List[str]] = None) -> RunOptions:
 ...
-    return CommandLineOptions.from_argparse(args)
+    return RunOptions.from_argparse(args)
```
```diff
-def run_training(run_seed: int, options: CommandLineOptions) -> None:
+def run_training(run_seed: int, options: RunOptions) -> None:
     """
     Launches training session.
     :param options: parsed command line arguments
 ...
-    # Docker Parameters
-    trainer_config_path = options.trainer_config_path
-    curriculum_folder = options.curriculum_folder
-    trainer_config_path = f"/{options.docker_target_name}/{trainer_config_path}"
-    if curriculum_folder is not None:
-        curriculum_folder = f"/{options.docker_target_name}/{curriculum_folder}"
-    trainer_config = load_config(trainer_config_path)
     port = options.base_port
     # Configure CSV, Tensorboard Writers and StatsReporter
 ...
     )
     env_manager = SubprocessEnvManager(env_factory, engine_config, options.num_envs)
     maybe_meta_curriculum = try_create_meta_curriculum(
-        curriculum_folder, env_manager, options.lesson
+        options.curriculum_config, env_manager, options.lesson
 ...
-        options.sampler_file_path, run_seed
+        options.sampler_config, run_seed
 ...
-        trainer_config,
+        options.trainer_config,
         summaries_dir,
         options.run_id,
         model_path,
 ...
     env_manager.close()
```
```diff
-def create_sampler_manager(sampler_file_path, run_seed=None):
-    sampler_config = None
-    if sampler_file_path is not None:
-        sampler_config = load_config(sampler_file_path)
+def create_sampler_manager(sampler_config, run_seed=None):
+    if sampler_config is not None:
         if "resampling-interval" in sampler_config:
             # Filter arguments that do not exist in the environment
             resample_interval = sampler_config.pop("resampling-interval")
 ...

 def try_create_meta_curriculum(
-    curriculum_folder: Optional[str], env: SubprocessEnvManager, lesson: int
+    curriculum_config: Optional[Dict], env: SubprocessEnvManager, lesson: int
 ...
-    if curriculum_folder is None:
+    if curriculum_config is None:
 ...
-    meta_curriculum = MetaCurriculum.from_directory(curriculum_folder)
+    meta_curriculum = MetaCurriculum(curriculum_config)
     return meta_curriculum
```

```diff
 def create_environment_factory(
-    env_path: str,
+    env_path: Optional[str],
     docker_target_name: Optional[str],
     no_graphics: bool,
     seed: Optional[int],
 ...
     return create_unity_environment
```
```diff
-def main():
+def run_cli(options: RunOptions) -> None:
     try:
         print(
             """
 ...
     except Exception:
         print("\n\n\tUnity Technologies\n")
     print(get_version_string())
-    options = parse_command_line()
-    trainer_logger.info(options)
     if options.debug:
         trainer_logger.setLevel("DEBUG")
         env_logger.setLevel("DEBUG")

+    trainer_logger.debug("Configuration for this run:")
+    trainer_logger.debug(json.dumps(options._asdict(), indent=4))
     run_seed = options.seed
     if options.cpu:
         os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 ...
     run_training(run_seed, options)


+def main():
+    run_cli(parse_command_line())


 # For python debugger to directly run this script
```
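With this split, `parse_command_line` only parses argv and defers to `RunOptions.from_argparse`, so a run can plausibly be driven without the CLI at all. A minimal sketch, assuming the field names from the `RunOptions` definition above and illustrative file paths:

```python
# Minimal sketch (assumptions noted inline): drive training through the new
# RunOptions/run_cli path without going through argv parsing.
from mlagents.trainers.learn import RunOptions, load_config, run_cli

options = RunOptions(
    # trainer_config is the only field without a default in the NamedTuple.
    trainer_config=load_config("config/trainer_config.yaml"),
    # curriculum_config defaults to None; pass a dict in the new single-file shape.
    curriculum_config=load_config("config/curricula/wall_jump.yaml"),
    run_id="wall-jump-curriculum",  # illustrative run id
    train_model=True,
)
run_cli(options)  # prints version info, then calls run_training(options.seed, options)
```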

ml-agents/mlagents/trainers/meta_curriculum.py (42 changes)


"""Contains the MetaCurriculum class."""
import os
from mlagents.trainers.exception import MetaCurriculumError
import logging

particular brain in the environment.
"""
def __init__(self, curricula: Dict[str, Curriculum]):
def __init__(self, curriculum_configs: Dict[str, Dict]):
used_reset_parameters: Set[str] = set()
for brain_name, curriculum in curricula.items():
self._brains_to_curricula[brain_name] = curriculum
config_keys: Set[str] = set(curriculum.get_config().keys())
used_reset_parameters: Set[str] = set()
for brain_name, curriculum_config in curriculum_configs.items():
self._brains_to_curricula[brain_name] = Curriculum(
brain_name, curriculum_config
)
config_keys: Set[str] = set(
self._brains_to_curricula[brain_name].get_config().keys()
)
# Check if any two curricula use the same reset params.
if config_keys & used_reset_parameters:

```diff
             config.update(curr_config)
         return config

-    @staticmethod
-    def from_directory(folder_path: str) -> "MetaCurriculum":
-        """
-        Creates a MetaCurriculum given a folder full of curriculum config files.
-
-        :param folder_path: The path to the folder which holds the curriculum configs
-        for this environment. The folder should contain JSON files whose names
-        are the brains that the curricula belong to.
-        """
-        try:
-            curricula = {}
-            for curriculum_filename in os.listdir(folder_path):
-                # This process requires JSON files
-                brain_name, extension = os.path.splitext(curriculum_filename)
-                if extension.lower() != ".json":
-                    continue
-                curriculum_filepath = os.path.join(folder_path, curriculum_filename)
-                curriculum_config = Curriculum.load_curriculum_file(curriculum_filepath)
-                curricula[brain_name] = Curriculum(brain_name, curriculum_config)
-            return MetaCurriculum(curricula)
-        except NotADirectoryError:
-            raise MetaCurriculumError(
-                f"{folder_path} is not a directory. Refer to the ML-Agents "
-                "curriculum learning docs."
-            )
```
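With `from_directory` gone, the constructor is the single entry point: it takes a dict mapping behavior names to curriculum config dicts, exactly the shape of the new single-file YAML. A minimal sketch of loading one (the file path is illustrative, and `yaml.safe_load` stands in for whatever loader `load_config` uses internally):

```python
# Minimal sketch: build a MetaCurriculum from the new single-file config.
# The constructor signature comes from the diff above; the file path and
# the direct use of PyYAML here are illustrative.
import yaml  # PyYAML

from mlagents.trainers.meta_curriculum import MetaCurriculum

with open("config/curricula/wall_jump.yaml") as f:
    # {"BigWallJump": {...}, "SmallWallJump": {...}}
    curriculum_configs = yaml.safe_load(f)

meta_curriculum = MetaCurriculum(curriculum_configs)
# Reset parameters for each behavior's current lesson:
print(meta_curriculum.get_config())
```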

ml-agents/mlagents/trainers/tests/test_learn.py (20 changes)


```diff
 import pytest
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, mock_open

 from mlagents.trainers import learn
 from mlagents.trainers.trainer_controller import TrainerController
 from mlagents.trainers.learn import parse_command_line
 ...
     assert mock_init.call_args[0][2] == "/dockertarget/summaries"


-def test_commandline_args():
+@patch("builtins.open", new_callable=mock_open, read_data="{}")
+def test_commandline_args(mock_file):
     # No args raises
     with pytest.raises(SystemExit):
 ...
     opt = parse_command_line(["mytrainerpath"])
-    assert opt.trainer_config_path == "mytrainerpath"
-    assert opt.curriculum_folder is None
-    assert opt.sampler_file_path is None
+    assert opt.trainer_config == {}
+    assert opt.curriculum_config is None
+    assert opt.sampler_config is None
     assert opt.keep_checkpoints == 5
     assert opt.lesson == 0
     assert opt.load_model is False
 ...
     ]
     opt = parse_command_line(full_args)
-    assert opt.trainer_config_path == "mytrainerpath"
-    assert opt.curriculum_folder == "./mycurriculum"
-    assert opt.sampler_file_path == "./mysample"
+    assert opt.trainer_config == {}
+    assert opt.curriculum_config == {}
+    assert opt.sampler_config == {}
     assert opt.keep_checkpoints == 42
     assert opt.lesson == 3
     assert opt.load_model is True
 ...
     assert opt.multi_gpu is True


-def test_env_args():
+@patch("builtins.open", new_callable=mock_open, read_data="{}")
+def test_env_args(mock_file):
     full_args = [
         "mytrainerpath",
         "--env=./myenvfile",
```

ml-agents/mlagents/trainers/tests/test_meta_curriculum.py (128 changes)


```diff
 import pytest
-from unittest.mock import patch, call, mock_open
+from unittest.mock import patch, Mock

-from mlagents.trainers.curriculum import Curriculum
 from mlagents.trainers.exception import MetaCurriculumError
+import json
 from mlagents.trainers.tests.test_simple_rl import (
     Simple1DEnvironment,
 ...
-from mlagents.trainers.tests.test_curriculum import (
-    dummy_curriculum_json_str,
-    dummy_curriculum_config,
-)
+from mlagents.trainers.tests.test_curriculum import dummy_curriculum_json_str


-@pytest.fixture
-def default_reset_parameters():
-    return {"param1": 1, "param2": 2, "param3": 3}
-
-
-@pytest.fixture
-def more_reset_parameters():
-    return {"param4": 4, "param5": 5, "param6": 6}


 @pytest.fixture
 ...
     return {"Brain1": 7, "Brain2": 8}


-@patch("mlagents.trainers.curriculum.Curriculum.get_config", return_value={})
-@patch(
-    "mlagents.trainers.curriculum.Curriculum.load_curriculum_file",
-    return_value=dummy_curriculum_config,
-)
-@patch("os.listdir", return_value=["Brain1.json", "Brain2.test.json"])
-def test_init_meta_curriculum_happy_path(
-    listdir, mock_curriculum_init, mock_curriculum_get_config, default_reset_parameters
-):
-    meta_curriculum = MetaCurriculum.from_directory("test/")
-
-    assert len(meta_curriculum.brains_to_curricula) == 2
-    assert "Brain1" in meta_curriculum.brains_to_curricula
-    assert "Brain2.test" in meta_curriculum.brains_to_curricula
-
-    calls = [call("test/Brain1.json"), call("test/Brain2.test.json")]
-    mock_curriculum_init.assert_has_calls(calls)
+def test_curriculum_config(param_name="test_param1", min_lesson_length=100):
+    return {
+        "measure": "progress",
+        "thresholds": [0.1, 0.3, 0.5],
+        "min_lesson_length": min_lesson_length,
+        "signal_smoothing": True,
+        "parameters": {f"{param_name}": [0.0, 4.0, 6.0, 8.0]},
+    }


-@patch("os.listdir", side_effect=NotADirectoryError())
-def test_init_meta_curriculum_bad_curriculum_folder_raises_error(listdir):
-    with pytest.raises(MetaCurriculumError):
-        MetaCurriculum.from_directory("test/")
+test_meta_curriculum_config = {
+    "Brain1": test_curriculum_config("test_param1"),
+    "Brain2": test_curriculum_config("test_param2"),
+}


-@patch("mlagents.trainers.curriculum.Curriculum")
-@patch("mlagents.trainers.curriculum.Curriculum")
-def test_set_lesson_nums(curriculum_a, curriculum_b):
-    meta_curriculum = MetaCurriculum({"Brain1": curriculum_a, "Brain2": curriculum_b})
+def test_set_lesson_nums():
+    meta_curriculum = MetaCurriculum(test_meta_curriculum_config)
 ...
-    assert curriculum_a.lesson_num == 1
-    assert curriculum_b.lesson_num == 3
+    assert meta_curriculum.brains_to_curricula["Brain1"].lesson_num == 1
+    assert meta_curriculum.brains_to_curricula["Brain2"].lesson_num == 3


-@patch("mlagents.trainers.curriculum.Curriculum")
-@patch("mlagents.trainers.curriculum.Curriculum")
-def test_increment_lessons(curriculum_a, curriculum_b, measure_vals):
-    meta_curriculum = MetaCurriculum({"Brain1": curriculum_a, "Brain2": curriculum_b})
+def test_increment_lessons(measure_vals):
+    meta_curriculum = MetaCurriculum(test_meta_curriculum_config)
+    meta_curriculum.brains_to_curricula["Brain1"] = Mock()
+    meta_curriculum.brains_to_curricula["Brain2"] = Mock()
 ...
-    curriculum_a.increment_lesson.assert_called_with(0.2)
-    curriculum_b.increment_lesson.assert_called_with(0.3)
+    meta_curriculum.brains_to_curricula["Brain1"].increment_lesson.assert_called_with(
+        0.2
+    )
+    meta_curriculum.brains_to_curricula["Brain2"].increment_lesson.assert_called_with(
+        0.3
+    )


-@patch("mlagents.trainers.curriculum.Curriculum")
 ...
 ):
     curriculum_a.min_lesson_length = 5
     curriculum_b.min_lesson_length = 10
-    meta_curriculum = MetaCurriculum({"Brain1": curriculum_a, "Brain2": curriculum_b})
+    meta_curriculum = MetaCurriculum(test_meta_curriculum_config)
+    meta_curriculum.brains_to_curricula["Brain1"] = curriculum_a
+    meta_curriculum.brains_to_curricula["Brain2"] = curriculum_b

     meta_curriculum.increment_lessons(measure_vals, reward_buff_sizes=reward_buff_sizes)
 ...

-@patch("mlagents.trainers.curriculum.Curriculum")
-@patch("mlagents.trainers.curriculum.Curriculum")
-def test_set_all_curriculums_to_lesson_num(curriculum_a, curriculum_b):
-    meta_curriculum = MetaCurriculum({"Brain1": curriculum_a, "Brain2": curriculum_b})
+def test_set_all_curriculums_to_lesson_num():
+    meta_curriculum = MetaCurriculum(test_meta_curriculum_config)
 ...
-    assert curriculum_a.lesson_num == 2
-    assert curriculum_b.lesson_num == 2
+    assert meta_curriculum.brains_to_curricula["Brain1"].lesson_num == 2
+    assert meta_curriculum.brains_to_curricula["Brain2"].lesson_num == 2


-@patch("mlagents.trainers.curriculum.Curriculum")
-@patch("mlagents.trainers.curriculum.Curriculum")
-def test_get_config(
-    curriculum_a, curriculum_b, default_reset_parameters, more_reset_parameters
-):
-    curriculum_a.get_config.return_value = default_reset_parameters
-    curriculum_b.get_config.return_value = default_reset_parameters
-    meta_curriculum = MetaCurriculum({"Brain1": curriculum_a, "Brain2": curriculum_b})
-    assert meta_curriculum.get_config() == default_reset_parameters
-
-    curriculum_b.get_config.return_value = more_reset_parameters
-    new_reset_parameters = dict(default_reset_parameters)
-    new_reset_parameters.update(more_reset_parameters)
-    assert meta_curriculum.get_config() == new_reset_parameters
+def test_get_config():
+    meta_curriculum = MetaCurriculum(test_meta_curriculum_config)
+    assert meta_curriculum.get_config() == {"test_param1": 0.0, "test_param2": 0.0}


-META_CURRICULUM_CONFIG = """
+TRAINER_CONFIG = """
 default:
     trainer: ppo
     batch_size: 16
 ...

 @pytest.mark.parametrize("curriculum_brain_name", [BRAIN_NAME, "WrongBrainName"])
 def test_simple_metacurriculum(curriculum_brain_name):
     env = Simple1DEnvironment(use_discrete=False)
-    with patch(
-        "builtins.open", new_callable=mock_open, read_data=dummy_curriculum_json_str
-    ):
-        curriculum_config = Curriculum.load_curriculum_file("TestBrain.json")
-    curriculum = Curriculum("TestBrain", curriculum_config)
-    mc = MetaCurriculum({curriculum_brain_name: curriculum})
-    _check_environment_trains(env, META_CURRICULUM_CONFIG, mc, -100.0)
+    curriculum_config = json.loads(dummy_curriculum_json_str)
+    mc = MetaCurriculum({curriculum_brain_name: curriculum_config})
+    _check_environment_trains(env, TRAINER_CONFIG, mc, -100.0)
```

ml-agents/setup.py (7 changes)


```diff
         'pypiwin32==223;platform_system=="Windows"',
     ],
     python_requires=">=3.6.1",
-    entry_points={"console_scripts": ["mlagents-learn=mlagents.trainers.learn:main"]},
+    entry_points={
+        "console_scripts": [
+            "mlagents-learn=mlagents.trainers.learn:main",
+            "mlagents-run-experiment=mlagents.trainers.run_experiment:main",
+        ]
+    },
     cmdclass={"verify": VerifyVersionCommand},
 )
```
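Once the package is reinstalled, setuptools generates a second console script alongside `mlagents-learn`: invoking `mlagents-run-experiment <experiment_config_path>` calls `main()` in `mlagents.trainers.run_experiment` (the new module below).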

config/curricula/test.yaml (9 changes, new file)


```yaml
TestBrain:
  measure: reward
  thresholds: [10, 20, 50]
  min_lesson_length: 100
  signal_smoothing: true
  parameters:
    param1: [0.7, 0.5, 0.3, 0.1]
    param2: [100, 50, 20, 15]
    param3: [0.2, 0.3, 0.7, 0.9]
```

config/curricula/wall_jump.yaml (16 changes, new file)


```yaml
BigWallJump:
  measure: progress
  thresholds: [0.1, 0.3, 0.5]
  min_lesson_length: 100
  signal_smoothing: true
  parameters:
    big_wall_min_height: [0.0, 4.0, 6.0, 8.0]
    big_wall_max_height: [4.0, 7.0, 8.0, 8.0]
SmallWallJump:
  measure: progress
  thresholds: [0.1, 0.3, 0.5]
  min_lesson_length: 100
  signal_smoothing: true
  parameters:
    small_wall_height: [1.5, 2.0, 2.5, 4.0]
```

ml-agents/mlagents/trainers/run_experiment.py (26 changes, new file)


```python
import argparse
from typing import Optional, List

from mlagents.trainers.learn import RunOptions, run_cli, load_config


def parse_command_line(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument("experiment_config_path")
    return parser.parse_args(argv)


def main():
    """
    Provides an alternative CLI interface to mlagents-learn, 'mlagents-run-experiment'.
    Accepts a JSON/YAML formatted mlagents.trainers.learn.RunOptions object, and executes
    the run loop as defined in mlagents.trainers.learn.run_cli.
    """
    args = parse_command_line()
    expt_config = load_config(args.experiment_config_path)
    run_cli(RunOptions(**expt_config))


if __name__ == "__main__":
    main()
```
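Because `main()` splats the loaded mapping directly into `RunOptions`, an experiment file's top-level keys must be `RunOptions` field names. Here is a hedged sketch of what such a file might look like: the field names come from the `RunOptions` definition in learn.py above, the trainer and curriculum sections reuse example configs from this change, and everything else is illustrative.

```yaml
# experiment.yaml (illustrative): top-level keys map 1:1 onto RunOptions fields.
trainer_config:       # required; RunOptions has no default for this field
  default:
    trainer: ppo
    batch_size: 16
curriculum_config:    # optional; same shape as config/curricula/test.yaml
  TestBrain:
    measure: reward
    thresholds: [10, 20, 50]
    min_lesson_length: 100
    signal_smoothing: true
    parameters:
      param1: [0.7, 0.5, 0.3, 0.1]
run_id: my-experiment
train_model: true
```

Such a file would then be launched with `mlagents-run-experiment experiment.yaml`, which loads it via `load_config` and hands the result to the same `run_cli` loop that `mlagents-learn` uses.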