浏览代码

Merge branch 'master' into develop-saver-name

/MLA-1734-demo-provider
Ruo-Ping Dong 4 年前
当前提交
a74c904a
共有 4 个文件被更改,包括 146 次插入20 次删除
  1. 19
      ml-agents/mlagents/trainers/learn.py
  2. 4
      ml-agents/mlagents/trainers/settings.py
  3. 38
      ml-agents/mlagents/trainers/trainer_util.py
  4. 105
      ml-agents/mlagents/trainers/tests/torch/test_saver.py

19
ml-agents/mlagents/trainers/learn.py


GaugeWriter,
ConsoleWriter,
)
from mlagents.trainers.cli_utils import parser
from mlagents.trainers.cli_utils import parser, DetectDefault
from mlagents_envs.environment import UnityEnvironment
from mlagents.trainers.settings import RunOptions

)
trainer_factory = TrainerFactory(
options.behaviors,
write_path,
not checkpoint_settings.inference,
checkpoint_settings.resume,
run_seed,
env_parameter_manager,
maybe_init_path,
False,
trainer_config=options.behaviors,
output_path=write_path,
train_model=not checkpoint_settings.inference,
load_model=checkpoint_settings.resume,
seed=run_seed,
param_manager=env_parameter_manager,
init_path=maybe_init_path,
multi_gpu=False,
force_torch="torch" in DetectDefault.non_default_args,
)
# Create controller and begin training.
tc = TrainerController(

4
ml-agents/mlagents/trainers/settings.py


else: # Base options
configured_dict[key] = val
# Apply --torch retroactively
if "torch" in DetectDefault.non_default_args:
for trainer_set in final_runoptions.behaviors.values():
trainer_set.framework = FrameworkType.PYTORCH
return final_runoptions
@staticmethod

38
ml-agents/mlagents/trainers/trainer_util.py


from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.settings import TrainerSettings, TrainerType
from mlagents.trainers.settings import TrainerSettings, TrainerType, FrameworkType
logger = get_logger(__name__)

param_manager: EnvironmentParameterManager,
init_path: str = None,
multi_gpu: bool = False,
force_torch: bool = False,
"""
The TrainerFactory generates the Trainers based on the configuration passed as
input.
:param trainer_config: A dictionary from behavior name to TrainerSettings
:param output_path: The path to the directory where the artifacts generated by
the trainer will be saved.
:param train_model: If True, the Trainers will train the model and if False,
only perform inference.
:param load_model: If True, the Trainer will load neural networks weights from
the previous run.
:param seed: The seed of the Trainers. Dictates how the neural networks will be
initialized.
:param param_manager: The EnvironmentParameterManager that will dictate when/if
the EnvironmentParameters must change.
:param init_path: Path from which to load model.
:param multi_gpu: If True, multi-gpu will be used. (currently not available)
:param force_torch: If True, the Trainers will all use the PyTorch framework
instead of the TensorFlow framework.
"""
self.trainer_config = trainer_config
self.output_path = output_path
self.init_path = init_path

self.param_manager = param_manager
self.multi_gpu = multi_gpu
self.ghost_controller = GhostController()
self._force_torch = force_torch
def generate(self, brain_name: str) -> Trainer:
if brain_name not in self.trainer_config.keys():
def generate(self, behavior_name: str) -> Trainer:
if behavior_name not in self.trainer_config.keys():
f"Behavior name {brain_name} does not match any behaviors specified in the trainer configuration file:"
f"{sorted(self.trainer_config.keys())}"
f"Behavior name {behavior_name} does not match any behaviors specified"
f"in the trainer configuration file: {sorted(self.trainer_config.keys())}"
trainer_settings = self.trainer_config[behavior_name]
if self._force_torch:
trainer_settings.framework = FrameworkType.PYTORCH
self.trainer_config[brain_name],
brain_name,
trainer_settings,
behavior_name,
self.output_path,
self.train_model,
self.load_model,

105
ml-agents/mlagents/trainers/tests/torch/test_saver.py


import pytest
from unittest import mock
import os
import numpy as np
import torch
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
from mlagents.trainers.saver.torch_saver import TorchSaver
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.torch.test_policy import create_policy_mock
def test_register(tmp_path):
trainer_params = TrainerSettings()
saver = TorchSaver(trainer_params, tmp_path)
opt = mock.Mock(spec=TorchPPOOptimizer)
opt.get_modules = mock.Mock(return_value={})
saver.register(opt)
assert saver.policy is None
trainer_params = TrainerSettings()
policy = create_policy_mock(trainer_params)
opt.get_modules = mock.Mock(return_value={})
saver.register(policy)
assert saver.policy is not None
def test_load_save(tmp_path):
path1 = os.path.join(tmp_path, "runid1")
path2 = os.path.join(tmp_path, "runid2")
trainer_params = TrainerSettings()
policy = create_policy_mock(trainer_params)
saver = TorchSaver(trainer_params, path1)
saver.register(policy)
saver.initialize_or_load(policy)
policy.set_step(2000)
mock_brain_name = "MockBrain"
saver.save_checkpoint(mock_brain_name, 2000)
assert len(os.listdir(tmp_path)) > 0
# Try load from this path
saver2 = TorchSaver(trainer_params, path1, load=True)
policy2 = create_policy_mock(trainer_params)
saver2.register(policy2)
saver2.initialize_or_load(policy2)
_compare_two_policies(policy, policy2)
assert policy2.get_current_step() == 2000
# Try initialize from path 1
trainer_params.init_path = path1
saver3 = TorchSaver(trainer_params, path2)
policy3 = create_policy_mock(trainer_params)
saver3.register(policy3)
saver3.initialize_or_load(policy3)
_compare_two_policies(policy2, policy3)
# Assert that the steps are 0.
assert policy3.get_current_step() == 0
# TorchPolicy.evalute() returns log_probs instead of all_log_probs like tf does.
# resulting in indeterministic results for testing.
# So here use sample_actions instead.
def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None:
"""
Make sure two policies have the same output for the same input.
"""
decision_step, _ = mb.create_steps_from_behavior_spec(
policy1.behavior_spec, num_agents=1
)
vec_vis_obs, masks = policy1._split_decision_step(decision_step)
vec_obs = [torch.as_tensor(vec_vis_obs.vector_observations)]
vis_obs = [torch.as_tensor(vis_ob) for vis_ob in vec_vis_obs.visual_observations]
memories = torch.as_tensor(
policy1.retrieve_memories(list(decision_step.agent_id))
).unsqueeze(0)
with torch.no_grad():
_, log_probs1, _, _, _ = policy1.sample_actions(
vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
)
_, log_probs2, _, _, _ = policy2.sample_actions(
vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
)
np.testing.assert_array_equal(log_probs1, log_probs2)
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"])
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"])
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_checkpoint_conversion(tmpdir, rnn, visual, discrete):
dummy_config = TrainerSettings()
model_path = os.path.join(tmpdir, "Mock_Brain")
policy = create_policy_mock(
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
)
trainer_params = TrainerSettings()
saver = TorchSaver(trainer_params, model_path)
saver.register(policy)
saver.save_checkpoint("Mock_Brain", 100)
assert os.path.isfile(model_path + "/Mock_Brain-100.onnx")
正在加载...
取消
保存