浏览代码

Merge branch 'master' into internal-policy-ghost

/internal-policy-ghost
Andrew Cohen 5 年前
当前提交
7a7eb324
共有 11 个文件被更改,包括 182 次插入22 次删除
  1. 7
      Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
  2. 1
      com.unity.ml-agents/CHANGELOG.md
  3. 8
      docs/Training-ML-Agents.md
  4. 11
      docs/Training-PPO.md
  5. 11
      docs/Training-SAC.md
  6. 20
      ml-agents/mlagents/trainers/learn.py
  7. 56
      ml-agents/mlagents/trainers/policy/tf_policy.py
  8. 2
      ml-agents/mlagents/trainers/tests/test_learn.py
  9. 49
      ml-agents/mlagents/trainers/tests/test_nn_policy.py
  10. 12
      ml-agents/mlagents/trainers/tests/test_trainer_util.py
  11. 27
      ml-agents/mlagents/trainers/trainer_util.py

7
Project/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs


public override float[] Heuristic()
{
var action = new float[2];
var action = new float[3];
action[0] = Input.GetAxis("Horizontal");
action[1] = Input.GetKey(KeyCode.Space) ? 1f : 0f;
action[0] = Input.GetAxis("Horizontal"); // Racket Movement
action[1] = Input.GetKey(KeyCode.Space) ? 1f : 0f; // Racket Jumping
action[2] = Input.GetAxis("Vertical"); // Racket Rotation
return action;
}

1
com.unity.ml-agents/CHANGELOG.md


- The Jupyter notebooks have been removed from the repository.
- Introduced the `SideChannelUtils` to register, unregister and access side channels.
- `Academy.FloatProperties` was removed, please use `SideChannelUtils.GetSideChannel<FloatPropertiesChannel>()` instead.
- Added ability to start training (initialize model weights) from a previous run ID. (#3710)
### Minor Changes
- Format of console output has changed slightly and now matches the name of the model/summary directory. (#3630, #3616)

8
docs/Training-ML-Agents.md


specified, you will not be able to continue with training. Use `--force` to force ML-Agents to
overwrite the existing data.
Alternatively, you might want to start a new training run but _initialize_ it using an already-trained
model. You may want to do this, for instance, if your environment changed and you want
a new model, but the old behavior is still better than random. You can do this by specifying `--initialize-from=<run-identifier>`, where `<run-identifier>` is the old run ID.
### Command Line Training Options
In addition to passing the path of the Unity executable containing your training

as the current agents in your scene.
* `--force`: Attempting to train a model with a run-id that has been used before will
throw an error. Use `--force` to force-overwrite this run-id's summary and model data.
* `--initialize-from=<run-identifier>`: Specify an old run-id here to initialize your model from
a previously trained model. Note that the previously saved models _must_ have the same behavior
parameters as your current environment.
* `--no-graphics`: Specify this option to run the Unity executable in
`-batchmode` and doesn't initialize the graphics driver. Use this only if your
training doesn't involve visual observations (reading from Pixels). See

| train_interval | How often to update the agent. | SAC |
| num_update | Number of mini-batches to update the agent with during each update. | SAC |
| use_recurrent | Train using a recurrent neural network. See [Using Recurrent Neural Networks](Feature-Memory.md). | PPO, SAC |
| init_path | Initialize trainer from a previously saved model. | PPO, SAC |
\*PPO = Proximal Policy Optimization, SAC = Soft Actor-Critic, BC = Behavioral Cloning (Imitation), GAIL = Generative Adversarial Imitaiton Learning

11
docs/Training-PPO.md


Typical Range: Approximately equal to PPO's `buffer_size`
### (Optional) Advanced: Initialize Model Path
`init_path` can be specified to initialize your model from a previous run before starting.
Note that the prior run should have used the same trainer configurations as the current run,
and have been saved with the same version of ML-Agents. You should provide the full path
to the folder where the checkpoints were saved, e.g. `./models/{run-id}/{behavior_name}`.
This option is provided in case you want to initialize different behaviors from different runs;
in most cases, it is sufficient to use the `--initialize-from` CLI parameter to initialize
all models from the same run.
## Training Statistics
To view training statistics, use TensorBoard. For information on launching and

11
docs/Training-SAC.md


Typical Range (Discrete): `32` - `512`
### (Optional) Advanced: Initialize Model Path
`init_path` can be specified to initialize your model from a previous run before starting.
Note that the prior run should have used the same trainer configurations as the current run,
and have been saved with the same version of ML-Agents. You should provide the full path
to the folder where the checkpoints were saved, e.g. `./models/{run-id}/{behavior_name}`.
This option is provided in case you want to initialize different behaviors from different runs;
in most cases, it is sufficient to use the `--initialize-from` CLI parameter to initialize
all models from the same run.
## Training Statistics
To view training statistics, use TensorBoard. For information on launching and

20
ml-agents/mlagents/trainers/learn.py


default=False,
dest="force",
action="store_true",
help="Force-overwrite existing models and summaries for a run-id that has been used "
help="Force-overwrite existing models and summaries for a run ID that has been used "
help="The directory name for model and summary statistics",
help="The run identifier for model and summary statistics.",
)
argparser.add_argument(
"--initialize-from",
metavar="RUN_ID",
default=None,
help="Specify a previously saved run ID from which to initialize the model from. "
"This can be used, for instance, to fine-tune an existing model on a new environment. ",
)
argparser.add_argument(
"--save-freq", default=50000, type=int, help="Frequency at which to save model"

dest="inference",
action="store_true",
help="Run in Python inference mode (don't train). Use with --resume to load a model trained with an "
"existing run-id.",
"existing run ID.",
)
argparser.add_argument(
"--base-port",

seed: int = parser.get_default("seed")
env_path: Optional[str] = parser.get_default("env_path")
run_id: str = parser.get_default("run_id")
initialize_from: str = parser.get_default("initialize_from")
load_model: bool = parser.get_default("load_model")
resume: bool = parser.get_default("resume")
force: bool = parser.get_default("force")

"""
with hierarchical_timer("run_training.setup"):
model_path = f"./models/{options.run_id}"
maybe_init_path = (
f"./models/{options.initialize_from}" if options.initialize_from else None
)
summaries_dir = "./summaries"
port = options.base_port

],
)
handle_existing_directories(
model_path, summaries_dir, options.resume, options.force
model_path, summaries_dir, options.resume, options.force, maybe_init_path
)
tb_writer = TensorboardWriter(summaries_dir, clear_past_data=not options.resume)
gauge_write = GaugeWriter()

not options.inference,
options.resume,
run_seed,
maybe_init_path,
maybe_meta_curriculum,
options.multi_gpu,
)

56
ml-agents/mlagents/trainers/policy/tf_policy.py


if self.use_continuous_act:
self.num_branches = self.brain.vector_action_space_size[0]
self.model_path = trainer_parameters["model_path"]
self.initialize_path = trainer_parameters.get("init_path", None)
self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
self.graph = tf.Graph()
self.sess = tf.Session(

init = tf.global_variables_initializer()
self.sess.run(init)
def _load_graph(self):
def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
logger.info("Loading Model for brain {}".format(self.brain.brain_name))
ckpt = tf.train.get_checkpoint_state(self.model_path)
logger.info(
"Loading model for brain {} from {}.".format(
self.brain.brain_name, model_path
)
)
ckpt = tf.train.get_checkpoint_state(model_path)
"--run-id. and that the previous run you are resuming from had the same "
"behavior names.".format(self.model_path)
"--run-id and that the previous run you are loading from had the same "
"behavior names.".format(model_path)
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
try:
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
except tf.errors.NotFoundError:
raise UnityPolicyException(
"The model {0} was found but could not be loaded. Make "
"sure the model is from the same version of ML-Agents, has the same behavior parameters, "
"and is using the same trainer configuration as the current run.".format(
model_path
)
)
if reset_global_steps:
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path
)
)
else:
logger.info(
"Resuming training from step {}.".format(self.get_current_step())
)
if self.load:
self._load_graph()
# If there is an initialize path, load from that. Else, load from the set model path.
# If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
# e.g., resume from an initialize path.
reset_steps = not self.load
if self.initialize_path is not None:
self._load_graph(self.initialize_path, reset_global_steps=reset_steps)
elif self.load:
self._load_graph(self.model_path, reset_global_steps=reset_steps)
else:
self._initialize_graph()

"""
step = self.sess.run(self.global_step)
return step
def _set_step(self, step: int) -> int:
"""
Sets current model step to step without creating additional ops.
:param step: Step to set the current model step to.
:return: The step the model was set to.
"""
current_step = self.get_current_step()
# Increment a positive or negative number of steps.
return self.increment_step(step - current_step)
def increment_step(self, n_steps):
"""

2
ml-agents/mlagents/trainers/tests/test_learn.py


None,
)
handle_dir_mock.assert_called_once_with(
"./models/ppo", "./summaries", False, False
"./models/ppo", "./summaries", False, False, None
)
StatsReporter.writers.clear() # make sure there aren't any writers as added by learn.py

49
ml-agents/mlagents/trainers/tests/test_nn_policy.py


import pytest
import os
from typing import Dict, Any
import numpy as np
from mlagents.tf_utils import tf

NUM_AGENTS = 12
def create_policy_mock(dummy_config, use_rnn, use_discrete, use_visual):
def create_policy_mock(
dummy_config: Dict[str, Any],
use_rnn: bool = False,
use_discrete: bool = True,
use_visual: bool = False,
load: bool = False,
seed: int = 0,
) -> NNPolicy:
mock_brain = mb.setup_mock_brain(
use_discrete,
use_visual,

trainer_parameters = dummy_config
trainer_parameters["keep_checkpoints"] = 3
trainer_parameters["use_recurrent"] = use_rnn
policy = NNPolicy(0, mock_brain, trainer_parameters, False, False)
policy = NNPolicy(seed, mock_brain, trainer_parameters, False, load)
def test_load_save(dummy_config, tmp_path):
path1 = os.path.join(tmp_path, "runid1")
path2 = os.path.join(tmp_path, "runid2")
trainer_params = dummy_config
trainer_params["model_path"] = path1
policy = create_policy_mock(trainer_params)
policy.initialize_or_load()
policy.save_model(2000)
assert len(os.listdir(tmp_path)) > 0
# Try load from this path
policy2 = create_policy_mock(trainer_params, load=True, seed=1)
policy2.initialize_or_load()
_compare_two_policies(policy, policy2)
# Try initialize from path 1
trainer_params["model_path"] = path2
trainer_params["init_path"] = path1
policy3 = create_policy_mock(trainer_params, load=False, seed=2)
policy3.initialize_or_load()
_compare_two_policies(policy2, policy3)
def _compare_two_policies(policy1: NNPolicy, policy2: NNPolicy) -> None:
"""
Make sure two policies have the same output for the same input.
"""
step = mb.create_batchedstep_from_brainparams(policy1.brain, num_agents=1)
run_out1 = policy1.evaluate(step, list(step.agent_id))
run_out2 = policy2.evaluate(step, list(step.agent_id))
np.testing.assert_array_equal(run_out2["log_probs"], run_out1["log_probs"])
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])

12
ml-agents/mlagents/trainers/tests/test_trainer_util.py


trainer_util.handle_existing_directories(model_path, summary_path, True, False)
# Test try to train w/ force - should work
trainer_util.handle_existing_directories(model_path, summary_path, False, True)
# Test initialize option
init_path = os.path.join(tmp_path, "runid2")
with pytest.raises(UnityTrainerException):
trainer_util.handle_existing_directories(
model_path, summary_path, False, True, init_path
)
os.mkdir(init_path)
# Should pass since the directory exists now.
trainer_util.handle_existing_directories(
model_path, summary_path, False, True, init_path
)

27
ml-agents/mlagents/trainers/trainer_util.py


train_model: bool,
load_model: bool,
seed: int,
init_path: str = None,
meta_curriculum: MetaCurriculum = None,
multi_gpu: bool = False,
):

self.model_path = model_path
self.init_path = init_path
self.keep_checkpoints = keep_checkpoints
self.train_model = train_model
self.load_model = load_model

self.load_model,
self.ghost_controller,
self.seed,
self.init_path,
self.meta_curriculum,
self.multi_gpu,
)

load_model: bool,
ghost_controller: GhostController,
seed: int,
init_path: str = None,
meta_curriculum: MetaCurriculum = None,
multi_gpu: bool = False,
) -> Trainer:

:param load_model: Whether to load the model or randomly initialize
:param ghost_controller: The object that coordinates ghost trainers
:param seed: The random seed to use
:param init_path: Path from which to load model, if different from model_path.
:param meta_curriculum: Optional meta_curriculum, used to determine a reward buffer length for PPOTrainer
:return:
"""

trainer_parameters["model_path"] = "{basedir}/{name}".format(
basedir=model_path, name=brain_name
)
if init_path is not None:
trainer_parameters["init_path"] = "{basedir}/{name}".format(
basedir=init_path, name=brain_name
)
trainer_parameters["keep_checkpoints"] = keep_checkpoints
if brain_name in trainer_config:
_brain_key: Any = brain_name

def handle_existing_directories(
model_path: str, summary_path: str, resume: bool, force: bool
model_path: str, summary_path: str, resume: bool, force: bool, init_path: str = None
) -> None:
"""
Validates that if the run_id model exists, we do not overwrite it unless --force is specified.

if model_path_exists:
if not resume and not force:
raise UnityTrainerException(
"Previous data from this run-id was found. "
"Either specify a new run-id, use --resume to resume this run, "
"Previous data from this run ID was found. "
"Either specify a new run ID, use --resume to resume this run, "
"Previous data from this run-id was not found. "
"Previous data from this run ID was not found. "
# Verify init path if specified.
if init_path is not None:
if not os.path.isdir(init_path):
raise UnityTrainerException(
"Could not initialize from {}. "
"Make sure models have already been saved with that run ID.".format(
init_path
)
)
正在加载...
取消
保存