
[life improvement] Moving Python files around (#4531)

* Moved components to the tf folder and moved the TrainerFactory to the `trainer` folder

* Addressing comments

* Editing the migrating doc

* fixing test
Branch: MLA-1734-demo-provider
GitHub · 4 years ago
Current commit: c188781b
31 changed files with 303 additions and 225 deletions
  1. docs/Migrating.md (16)
  2. ml-agents/mlagents/trainers/learn.py (5)
  3. ml-agents/mlagents/trainers/optimizer/tf_optimizer.py (4)
  4. ml-agents/mlagents/trainers/ppo/trainer.py (2)
  5. ml-agents/mlagents/trainers/sac/trainer.py (2)
  6. ml-agents/mlagents/trainers/tests/check_env_trains.py (2)
  7. ml-agents/mlagents/trainers/tests/tensorflow/test_bcmodule.py (2)
  8. ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py (2)
  9. ml-agents/mlagents/trainers/tests/test_learn.py (2)
  10. ml-agents/mlagents/trainers/tests/test_trainer_controller.py (1)
  11. ml-agents/mlagents/trainers/tests/test_trainer_util.py (21)
  12. ml-agents/mlagents/trainers/trainer/__init__.py (1)
  13. ml-agents/mlagents/trainers/trainer/rl_trainer.py (5)
  14. ml-agents/mlagents/trainers/trainer_controller.py (2)
  15. ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity/signal.py (9)
  16. ml-agents/mlagents/trainers/tf/components/reward_signals/gail/signal.py (7)
  17. ml-agents/mlagents/trainers/directory_utils.py (42)
  18. ml-agents/mlagents/trainers/trainer/trainer_factory.py (156)
  19. ml-agents/mlagents/trainers/tf/components/reward_signals/extrinsic/signal.py (13)
  20. ml-agents/mlagents/trainers/tf/components/reward_signals/reward_signal_factory.py (37)
  21. ml-agents/mlagents/trainers/trainer_util.py (197)
  22. /ml-agents/mlagents/trainers/tf/components/__init__.py (0, renamed)
  23. /ml-agents/mlagents/trainers/tf/components/bc (0, renamed)
  24. /ml-agents/mlagents/trainers/tf/components/reward_signals/__init__.py (0, renamed)
  25. /ml-agents/mlagents/trainers/tf/components/reward_signals/extrinsic/__init__.py (0, renamed)
  26. /ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity (0, renamed)
  27. /ml-agents/mlagents/trainers/tf/components/reward_signals/gail (0, renamed)

docs/Migrating.md (16 changes)


  # Migrating
- ## Migrating from Release 3 to latest
+ ## Migrating from Release 7 to latest
+ ### Important changes
+ - Some trainer files were moved. If you were using the `TrainerFactory` class, it was moved to
+   the `trainers/trainer` folder.
+ - The `components` folder containing `bc` and `reward_signals` code was moved to the `trainers/tf`
+   folder.
+ ### Steps to Migrate
+ - Replace `from mlagents.trainers.trainer_util import TrainerFactory` with `from mlagents.trainers.trainer import TrainerFactory`.
+ - Replace `from mlagents.trainers.trainer_util import handle_existing_directories` with `from mlagents.trainers.directory_utils import validate_existing_directories`.
+ - Replace `mlagents.trainers.components` with `mlagents.trainers.tf.components` in your import statements.
+ ## Migrating from Release 3 to Release 7
  ### Important changes
  - The Parameter Randomization feature has been merged with the Curriculum feature. It is now possible to specify a sampler
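Taken together, the migration steps above amount to a handful of import rewrites. A minimal before/after sketch (the "before" paths are the pre-PR layout quoted in the steps):

# Before this PR -- these imports no longer resolve:
#   from mlagents.trainers.trainer_util import TrainerFactory, handle_existing_directories
#   from mlagents.trainers.components.bc.module import BCModule

# After this PR:
from mlagents.trainers.trainer import TrainerFactory
from mlagents.trainers.directory_utils import validate_existing_directories
from mlagents.trainers.tf.components.bc.module import BCModule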

ml-agents/mlagents/trainers/learn.py (5 changes)


  from mlagents import tf_utils
  from mlagents.trainers.trainer_controller import TrainerController
  from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
- from mlagents.trainers.trainer_util import TrainerFactory, handle_existing_directories
+ from mlagents.trainers.trainer import TrainerFactory
+ from mlagents.trainers.directory_utils import validate_existing_directories
  from mlagents.trainers.stats import (
      TensorboardWriter,
      StatsReporter,

  run_logs_dir = os.path.join(write_path, "run_logs")
  port: Optional[int] = env_settings.base_port
  # Check if directory exists
- handle_existing_directories(
+ validate_existing_directories(
      write_path,
      checkpoint_settings.resume,
      checkpoint_settings.force,

ml-agents/mlagents/trainers/optimizer/tf_optimizer.py (4 changes)


  from mlagents.trainers.policy.tf_policy import TFPolicy
  from mlagents.trainers.optimizer import Optimizer
  from mlagents.trainers.trajectory import SplitObservations
- from mlagents.trainers.components.reward_signals.reward_signal_factory import (
+ from mlagents.trainers.tf.components.reward_signals.reward_signal_factory import (
- from mlagents.trainers.components.bc.module import BCModule
+ from mlagents.trainers.tf.components.bc.module import BCModule

  class TFOptimizer(Optimizer):  # pylint: disable=W0223

ml-agents/mlagents/trainers/ppo/trainer.py (2 changes)


  from mlagents.trainers.trajectory import Trajectory
  from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
  from mlagents.trainers.settings import TrainerSettings, PPOSettings, FrameworkType
- from mlagents.trainers.components.reward_signals import RewardSignal
+ from mlagents.trainers.tf.components.reward_signals import RewardSignal
  from mlagents import torch_utils

  if torch_utils.is_available():

ml-agents/mlagents/trainers/sac/trainer.py (2 changes)


  from mlagents.trainers.trajectory import Trajectory, SplitObservations
  from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
  from mlagents.trainers.settings import TrainerSettings, SACSettings, FrameworkType
- from mlagents.trainers.components.reward_signals import RewardSignal
+ from mlagents.trainers.tf.components.reward_signals import RewardSignal
  from mlagents import torch_utils

  if torch_utils.is_available():

ml-agents/mlagents/trainers/tests/check_env_trains.py (2 changes)


  import numpy as np
  from typing import Dict
  from mlagents.trainers.trainer_controller import TrainerController
- from mlagents.trainers.trainer_util import TrainerFactory
+ from mlagents.trainers.trainer import TrainerFactory
  from mlagents.trainers.simple_env_manager import SimpleEnvManager
  from mlagents.trainers.stats import StatsReporter, StatsWriter, StatsSummary
  from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager

ml-agents/mlagents/trainers/tests/tensorflow/test_bcmodule.py (2 changes)


  import numpy as np
  from mlagents.trainers.policy.tf_policy import TFPolicy
- from mlagents.trainers.components.bc.module import BCModule
+ from mlagents.trainers.tf.components.bc.module import BCModule
  from mlagents.trainers.settings import (
      TrainerSettings,
      BehavioralCloningSettings,

ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py (2 changes)


      RecordEnvironment,
  )
  from mlagents.trainers.trainer_controller import TrainerController
- from mlagents.trainers.trainer_util import TrainerFactory
+ from mlagents.trainers.trainer import TrainerFactory
  from mlagents.trainers.simple_env_manager import SimpleEnvManager
  from mlagents.trainers.demo_loader import write_demo
  from mlagents.trainers.stats import StatsReporter, StatsWriter, StatsSummary

ml-agents/mlagents/trainers/tests/test_learn.py (2 changes)


@patch("mlagents.trainers.learn.write_timing_tree")
@patch("mlagents.trainers.learn.write_run_options")
@patch("mlagents.trainers.learn.handle_existing_directories")
@patch("mlagents.trainers.learn.validate_existing_directories")
@patch("mlagents.trainers.learn.TrainerFactory")
@patch("mlagents.trainers.learn.SubprocessEnvManager")
@patch("mlagents.trainers.learn.create_environment_factory")

ml-agents/mlagents/trainers/tests/test_trainer_controller.py (1 change)


trainer_mock.write_tensorboard_text = MagicMock()
tc = basic_trainer_controller
tc.initialize_trainers = MagicMock()
tc.trainers = {"testbrain": trainer_mock}
tc.advance = MagicMock()
tc.trainers["testbrain"].get_step = 0

ml-agents/mlagents/trainers/tests/test_trainer_util.py (21 changes)


  import os
  from unittest.mock import patch
- from mlagents.trainers import trainer_util
+ from mlagents.trainers.trainer import TrainerFactory
  from mlagents.trainers.cli_utils import load_config, _load_config
  from mlagents.trainers.ppo.trainer import PPOTrainer
  from mlagents.trainers.exception import TrainerConfigError, UnityTrainerException
+ from mlagents.trainers.directory_utils import validate_existing_directories

  @pytest.fixture

  assert artifact_path == os.path.join(output_path, brain_name)
  with patch.object(PPOTrainer, "__init__", mock_constructor):
-     trainer_factory = trainer_util.TrainerFactory(
+     trainer_factory = TrainerFactory(
          trainer_config=base_config,
          output_path=output_path,
          train_model=train_model,

  brain_name = "testbrain"
  no_default_config = RunOptions().behaviors
- trainer_factory = trainer_util.TrainerFactory(
+ trainer_factory = TrainerFactory(
      trainer_config=no_default_config,
      output_path="output_path",
      train_model=True,

  def test_existing_directories(tmp_path):
      output_path = os.path.join(tmp_path, "runid")
      # Test fresh new unused path - should do nothing.
-     trainer_util.handle_existing_directories(output_path, False, False)
+     validate_existing_directories(output_path, False, False)
-     trainer_util.handle_existing_directories(output_path, True, False)
+     validate_existing_directories(output_path, True, False)
-     trainer_util.handle_existing_directories(output_path, False, False)
+     validate_existing_directories(output_path, False, False)
-     trainer_util.handle_existing_directories(output_path, True, False)
+     validate_existing_directories(output_path, True, False)
-     trainer_util.handle_existing_directories(output_path, False, True)
+     validate_existing_directories(output_path, False, True)
-     trainer_util.handle_existing_directories(output_path, False, True, init_path)
+     validate_existing_directories(output_path, False, True, init_path)
-     trainer_util.handle_existing_directories(output_path, False, True, init_path)
+     validate_existing_directories(output_path, False, True, init_path)

ml-agents/mlagents/trainers/trainer/__init__.py (1 change)


  from mlagents.trainers.trainer.trainer import Trainer  # noqa
+ from mlagents.trainers.trainer.trainer_factory import TrainerFactory  # noqa

ml-agents/mlagents/trainers/trainer/rl_trainer.py (5 changes)


  from mlagents.trainers.optimizer import Optimizer
  from mlagents.trainers.buffer import AgentBuffer
  from mlagents.trainers.trainer import Trainer
- from mlagents.trainers.components.reward_signals import RewardSignalResult, RewardSignal
+ from mlagents.trainers.tf.components.reward_signals import (
+     RewardSignalResult,
+     RewardSignal,
+ )
  from mlagents_envs.timers import hierarchical_timer
  from mlagents_envs.base_env import BehaviorSpec
  from mlagents.trainers.policy.policy import Policy

ml-agents/mlagents/trainers/trainer_controller.py (2 changes)


  )
  from mlagents.trainers.trainer import Trainer
  from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
- from mlagents.trainers.trainer_util import TrainerFactory
+ from mlagents.trainers.trainer import TrainerFactory
  from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
  from mlagents.trainers.agent_processor import AgentManager
  from mlagents.tf_utils.globals import get_rank

ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity/signal.py (9 changes)


  import numpy as np
  from mlagents.tf_utils import tf
- from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
- from mlagents.trainers.components.reward_signals.curiosity.model import CuriosityModel
+ from mlagents.trainers.tf.components.reward_signals import (
+     RewardSignal,
+     RewardSignalResult,
+ )
+ from mlagents.trainers.tf.components.reward_signals.curiosity.model import (
+     CuriosityModel,
+ )
  from mlagents.trainers.policy.tf_policy import TFPolicy
  from mlagents.trainers.buffer import AgentBuffer
  from mlagents.trainers.settings import CuriositySettings

ml-agents/mlagents/trainers/tf/components/reward_signals/gail/signal.py (7 changes)


  import numpy as np
  from mlagents.tf_utils import tf
- from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
+ from mlagents.trainers.tf.components.reward_signals import (
+     RewardSignal,
+     RewardSignalResult,
+ )
- from .model import GAILModel
+ from mlagents.trainers.tf.components.reward_signals.gail.model import GAILModel
  from mlagents.trainers.demo_loader import demo_to_buffer
  from mlagents.trainers.buffer import AgentBuffer
  from mlagents.trainers.settings import GAILSettings

ml-agents/mlagents/trainers/directory_utils.py (42 changes; new file)


import os
from mlagents.trainers.exception import UnityTrainerException


def validate_existing_directories(
    output_path: str, resume: bool, force: bool, init_path: str = None
) -> None:
    """
    Validates that if the run_id model exists, we do not overwrite it unless --force is specified.
    Throws an exception if resume isn't specified and run_id exists. Throws an exception
    if --resume is specified and run-id was not found.
    :param output_path: The output path for this run, derived from the run ID.
    :param resume: Whether or not the --resume flag was passed.
    :param force: Whether or not the --force flag was passed.
    :param init_path: Optional path from which to initialize the model.
    """
    output_path_exists = os.path.isdir(output_path)
    if output_path_exists:
        if not resume and not force:
            raise UnityTrainerException(
                "Previous data from this run ID was found. "
                "Either specify a new run ID, use --resume to resume this run, "
                "or use the --force parameter to overwrite existing data."
            )
    else:
        if resume:
            raise UnityTrainerException(
                "Previous data from this run ID was not found. "
                "Train a new run by removing the --resume flag."
            )
    # Verify init path if specified.
    if init_path is not None:
        if not os.path.isdir(init_path):
            raise UnityTrainerException(
                "Could not initialize from {}. "
                "Make sure models have already been saved with that run ID.".format(
                    init_path
                )
            )
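A minimal usage sketch of the new helper, mirroring the patterns in test_trainer_util.py above; the temporary directory and run ID here are illustrative, not from the PR:

import os
import tempfile

from mlagents.trainers.directory_utils import validate_existing_directories
from mlagents.trainers.exception import UnityTrainerException

run_dir = os.path.join(tempfile.mkdtemp(), "my_run_id")
validate_existing_directories(run_dir, resume=False, force=False)  # fresh path: no-op

os.makedirs(run_dir)
try:
    # An existing output directory without --resume or --force should raise.
    validate_existing_directories(run_dir, resume=False, force=False)
except UnityTrainerException:
    print("existing run ID requires --resume or --force")
validate_existing_directories(run_dir, resume=True, force=False)  # resuming is fine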

ml-agents/mlagents/trainers/trainer/trainer_factory.py (156 changes; new file)


import os
from typing import Dict

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.settings import TrainerSettings, TrainerType, FrameworkType

logger = get_logger(__name__)


class TrainerFactory:
    def __init__(
        self,
        trainer_config: Dict[str, TrainerSettings],
        output_path: str,
        train_model: bool,
        load_model: bool,
        seed: int,
        param_manager: EnvironmentParameterManager,
        init_path: str = None,
        multi_gpu: bool = False,
        force_torch: bool = False,
    ):
        """
        The TrainerFactory generates the Trainers based on the configuration passed as
        input.
        :param trainer_config: A dictionary from behavior name to TrainerSettings
        :param output_path: The path to the directory where the artifacts generated by
        the trainer will be saved.
        :param train_model: If True, the Trainers will train the model and if False,
        only perform inference.
        :param load_model: If True, the Trainer will load neural network weights from
        the previous run.
        :param seed: The seed of the Trainers. Dictates how the neural networks will be
        initialized.
        :param param_manager: The EnvironmentParameterManager that will dictate when/if
        the EnvironmentParameters must change.
        :param init_path: Path from which to load model.
        :param multi_gpu: If True, multi-gpu will be used. (currently not available)
        :param force_torch: If True, the Trainers will all use the PyTorch framework
        instead of the TensorFlow framework.
        """
        self.trainer_config = trainer_config
        self.output_path = output_path
        self.init_path = init_path
        self.train_model = train_model
        self.load_model = load_model
        self.seed = seed
        self.param_manager = param_manager
        self.multi_gpu = multi_gpu
        self.ghost_controller = GhostController()
        self._force_torch = force_torch

    def generate(self, behavior_name: str) -> Trainer:
        if behavior_name not in self.trainer_config.keys():
            logger.warning(
                f"Behavior name {behavior_name} does not match any behaviors specified "
                f"in the trainer configuration file: {sorted(self.trainer_config.keys())}"
            )
        trainer_settings = self.trainer_config[behavior_name]
        if self._force_torch:
            trainer_settings.framework = FrameworkType.PYTORCH
        return TrainerFactory._initialize_trainer(
            trainer_settings,
            behavior_name,
            self.output_path,
            self.train_model,
            self.load_model,
            self.ghost_controller,
            self.seed,
            self.param_manager,
            self.init_path,
            self.multi_gpu,
        )

    @staticmethod
    def _initialize_trainer(
        trainer_settings: TrainerSettings,
        brain_name: str,
        output_path: str,
        train_model: bool,
        load_model: bool,
        ghost_controller: GhostController,
        seed: int,
        param_manager: EnvironmentParameterManager,
        init_path: str = None,
        multi_gpu: bool = False,
    ) -> Trainer:
        """
        Initializes a trainer given a provided trainer configuration and brain parameters, as well as
        some general training session options.
        :param trainer_settings: Original trainer configuration loaded from YAML
        :param brain_name: Name of the brain to be associated with trainer
        :param output_path: Path to save the model and summary statistics
        :param train_model: Whether to train the model (vs. run inference)
        :param load_model: Whether to load the model or randomly initialize
        :param ghost_controller: The object that coordinates ghost trainers
        :param seed: The random seed to use
        :param param_manager: EnvironmentParameterManager, used to determine a reward buffer length for PPOTrainer
        :param init_path: Path from which to load model, if different from model_path.
        :return: The initialized Trainer.
        """
        trainer_artifact_path = os.path.join(output_path, brain_name)
        if init_path is not None:
            trainer_settings.init_path = os.path.join(init_path, brain_name)

        min_lesson_length = param_manager.get_minimum_reward_buffer_size(brain_name)
        trainer: Trainer = None  # type: ignore  # will be set to one of these, or raise
        trainer_type = trainer_settings.trainer_type

        if trainer_type == TrainerType.PPO:
            trainer = PPOTrainer(
                brain_name,
                min_lesson_length,
                trainer_settings,
                train_model,
                load_model,
                seed,
                trainer_artifact_path,
            )
        elif trainer_type == TrainerType.SAC:
            trainer = SACTrainer(
                brain_name,
                min_lesson_length,
                trainer_settings,
                train_model,
                load_model,
                seed,
                trainer_artifact_path,
            )
        else:
            raise TrainerConfigError(
                f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}'
            )

        if trainer_settings.self_play is not None:
            trainer = GhostTrainer(
                trainer,
                brain_name,
                ghost_controller,
                min_lesson_length,
                trainer_settings,
                train_model,
                trainer_artifact_path,
            )
        return trainer
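A rough sketch of how the relocated factory is driven. learn.py is the real call site; here the run ID and behavior name are illustrative, and the EnvironmentParameterManager arguments follow my reading of learn.py at this commit, so treat them as assumptions. With default settings, generate() logs a warning for the unknown behavior and falls back to the default PPO configuration:

from mlagents.trainers.trainer import TrainerFactory
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.settings import RunOptions

options = RunOptions()  # default settings; a real run parses a YAML config
param_manager = EnvironmentParameterManager(
    options.environment_parameters, run_seed=0, restore=False
)
factory = TrainerFactory(
    trainer_config=options.behaviors,
    output_path="results/my_run",
    train_model=True,
    load_model=False,
    seed=0,
    param_manager=param_manager,
)
trainer = factory.generate("MyBehavior")  # a PPOTrainer under default settings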

ml-agents/mlagents/trainers/tf/components/reward_signals/extrinsic/signal.py (13 changes)


import numpy as np
from mlagents.trainers.tf.components.reward_signals import (
    RewardSignal,
    RewardSignalResult,
)
from mlagents.trainers.buffer import AgentBuffer


class ExtrinsicRewardSignal(RewardSignal):
    def evaluate_batch(self, mini_batch: AgentBuffer) -> RewardSignalResult:
        env_rews = np.array(mini_batch["environment_rewards"], dtype=np.float32)
        return RewardSignalResult(self.strength * env_rews, env_rews)
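The extrinsic signal simply rescales the reward coming from the environment by the configured strength and reports both values. A standalone numpy illustration of that arithmetic (the strength value is made up):

import numpy as np

strength = 0.5  # hypothetical RewardSignal.strength taken from settings
environment_rewards = np.array([0.0, 1.0, -1.0], dtype=np.float32)

scaled_reward = strength * environment_rewards  # what the trainer optimizes
unscaled_reward = environment_rewards           # what gets reported in stats
print(scaled_reward, unscaled_reward)           # [ 0.   0.5 -0.5] [ 0.  1. -1.]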

ml-agents/mlagents/trainers/tf/components/reward_signals/reward_signal_factory.py (37 changes)


from typing import Dict, Type
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.tf.components.reward_signals import RewardSignal
from mlagents.trainers.tf.components.reward_signals.extrinsic.signal import (
    ExtrinsicRewardSignal,
)
from mlagents.trainers.tf.components.reward_signals.gail.signal import GAILRewardSignal
from mlagents.trainers.tf.components.reward_signals.curiosity.signal import (
    CuriosityRewardSignal,
)
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType

NAME_TO_CLASS: Dict[RewardSignalType, Type[RewardSignal]] = {
    RewardSignalType.EXTRINSIC: ExtrinsicRewardSignal,
    RewardSignalType.CURIOSITY: CuriosityRewardSignal,
    RewardSignalType.GAIL: GAILRewardSignal,
}


def create_reward_signal(
    policy: TFPolicy, name: RewardSignalType, settings: RewardSignalSettings
) -> RewardSignal:
    """
    Creates a reward signal class based on the name and settings provided.
    :param policy: The policy class to which the reward will be applied.
    :param name: The name of the reward signal.
    :param settings: The settings for that reward signal.
    :return: The reward signal class instantiated.
    """
    rcls = NAME_TO_CLASS.get(name)
    if not rcls:
        raise UnityTrainerException(f"Unknown reward signal type {name}")
    class_inst = rcls(policy, settings)
    return class_inst
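A sketch of how an optimizer might request a signal from this factory. The `policy` and `mini_batch` objects are assumed to be built elsewhere (see tf_optimizer.py above); only the factory call itself is taken from this file:

from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType
from mlagents.trainers.tf.components.reward_signals.reward_signal_factory import (
    create_reward_signal,
)

settings = RewardSignalSettings(gamma=0.99, strength=1.0)
# `policy` is assumed to be an already-constructed TFPolicy.
signal = create_reward_signal(policy, RewardSignalType.EXTRINSIC, settings)
# `mini_batch` is assumed to be an AgentBuffer slice from the update buffer.
reward_result = signal.evaluate_batch(mini_batch)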

ml-agents/mlagents/trainers/trainer_util.py (197 changes; file removed)


import os
from typing import Dict

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
from mlagents.trainers.exception import TrainerConfigError
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.settings import TrainerSettings, TrainerType, FrameworkType

logger = get_logger(__name__)


class TrainerFactory:
    def __init__(
        self,
        trainer_config: Dict[str, TrainerSettings],
        output_path: str,
        train_model: bool,
        load_model: bool,
        seed: int,
        param_manager: EnvironmentParameterManager,
        init_path: str = None,
        multi_gpu: bool = False,
        force_torch: bool = False,
    ):
        """
        The TrainerFactory generates the Trainers based on the configuration passed as
        input.
        :param trainer_config: A dictionary from behavior name to TrainerSettings
        :param output_path: The path to the directory where the artifacts generated by
        the trainer will be saved.
        :param train_model: If True, the Trainers will train the model and if False,
        only perform inference.
        :param load_model: If True, the Trainer will load neural networks weights from
        the previous run.
        :param seed: The seed of the Trainers. Dictates how the neural networks will be
        initialized.
        :param param_manager: The EnvironmentParameterManager that will dictate when/if
        the EnvironmentParameters must change.
        :param init_path: Path from which to load model.
        :param multi_gpu: If True, multi-gpu will be used. (currently not available)
        :param force_torch: If True, the Trainers will all use the PyTorch framework
        instead of the TensorFlow framework.
        """
        self.trainer_config = trainer_config
        self.output_path = output_path
        self.init_path = init_path
        self.train_model = train_model
        self.load_model = load_model
        self.seed = seed
        self.param_manager = param_manager
        self.multi_gpu = multi_gpu
        self.ghost_controller = GhostController()
        self._force_torch = force_torch

    def generate(self, behavior_name: str) -> Trainer:
        if behavior_name not in self.trainer_config.keys():
            logger.warning(
                f"Behavior name {behavior_name} does not match any behaviors specified"
                f"in the trainer configuration file: {sorted(self.trainer_config.keys())}"
            )
        trainer_settings = self.trainer_config[behavior_name]
        if self._force_torch:
            trainer_settings.framework = FrameworkType.PYTORCH
        return initialize_trainer(
            trainer_settings,
            behavior_name,
            self.output_path,
            self.train_model,
            self.load_model,
            self.ghost_controller,
            self.seed,
            self.param_manager,
            self.init_path,
            self.multi_gpu,
        )


def initialize_trainer(
    trainer_settings: TrainerSettings,
    brain_name: str,
    output_path: str,
    train_model: bool,
    load_model: bool,
    ghost_controller: GhostController,
    seed: int,
    param_manager: EnvironmentParameterManager,
    init_path: str = None,
    multi_gpu: bool = False,
) -> Trainer:
    """
    Initializes a trainer given a provided trainer configuration and brain parameters, as well as
    some general training session options.
    :param trainer_settings: Original trainer configuration loaded from YAML
    :param brain_name: Name of the brain to be associated with trainer
    :param output_path: Path to save the model and summary statistics
    :param keep_checkpoints: How many model checkpoints to keep
    :param train_model: Whether to train the model (vs. run inference)
    :param load_model: Whether to load the model or randomly initialize
    :param ghost_controller: The object that coordinates ghost trainers
    :param seed: The random seed to use
    :param param_manager: EnvironmentParameterManager, used to determine a reward buffer length for PPOTrainer
    :param init_path: Path from which to load model, if different from model_path.
    :return:
    """
    trainer_artifact_path = os.path.join(output_path, brain_name)
    if init_path is not None:
        trainer_settings.init_path = os.path.join(init_path, brain_name)

    min_lesson_length = param_manager.get_minimum_reward_buffer_size(brain_name)
    trainer: Trainer = None  # type: ignore  # will be set to one of these, or raise
    trainer_type = trainer_settings.trainer_type

    if trainer_type == TrainerType.PPO:
        trainer = PPOTrainer(
            brain_name,
            min_lesson_length,
            trainer_settings,
            train_model,
            load_model,
            seed,
            trainer_artifact_path,
        )
    elif trainer_type == TrainerType.SAC:
        trainer = SACTrainer(
            brain_name,
            min_lesson_length,
            trainer_settings,
            train_model,
            load_model,
            seed,
            trainer_artifact_path,
        )
    else:
        raise TrainerConfigError(
            f'The trainer config contains an unknown trainer type "{trainer_type}" for brain {brain_name}'
        )

    if trainer_settings.self_play is not None:
        trainer = GhostTrainer(
            trainer,
            brain_name,
            ghost_controller,
            min_lesson_length,
            trainer_settings,
            train_model,
            trainer_artifact_path,
        )
    return trainer


def handle_existing_directories(
    output_path: str, resume: bool, force: bool, init_path: str = None
) -> None:
    """
    Validates that if the run_id model exists, we do not overwrite it unless --force is specified.
    Throws an exception if resume isn't specified and run_id exists. Throws an exception
    if --resume is specified and run-id was not found.
    :param model_path: The model path specified.
    :param summary_path: The summary path to be used.
    :param resume: Whether or not the --resume flag was passed.
    :param force: Whether or not the --force flag was passed.
    """
    output_path_exists = os.path.isdir(output_path)
    if output_path_exists:
        if not resume and not force:
            raise UnityTrainerException(
                "Previous data from this run ID was found. "
                "Either specify a new run ID, use --resume to resume this run, "
                "or use the --force parameter to overwrite existing data."
            )
    else:
        if resume:
            raise UnityTrainerException(
                "Previous data from this run ID was not found. "
                "Train a new run by removing the --resume flag."
            )
    # Verify init path if specified.
    if init_path is not None:
        if not os.path.isdir(init_path):
            raise UnityTrainerException(
                "Could not initialize from {}. "
                "Make sure models have already been saved with that run ID.".format(
                    init_path
                )
            )

/ml-agents/mlagents/trainers/components/__init__.py → /ml-agents/mlagents/trainers/tf/components/__init__.py

/ml-agents/mlagents/trainers/components/bc → /ml-agents/mlagents/trainers/tf/components/bc

/ml-agents/mlagents/trainers/components/reward_signals/__init__.py → /ml-agents/mlagents/trainers/tf/components/reward_signals/__init__.py

/ml-agents/mlagents/trainers/components/reward_signals/extrinsic/__init__.py → /ml-agents/mlagents/trainers/tf/components/reward_signals/extrinsic/__init__.py

/ml-agents/mlagents/trainers/components/reward_signals/curiosity → /ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity

/ml-agents/mlagents/trainers/components/reward_signals/gail → /ml-agents/mlagents/trainers/tf/components/reward_signals/gail
