|
|
|
|
|
|
) |
|
|
|
from mlagents.trainers.cli_utils import parser, DetectDefault |
|
|
|
from mlagents_envs.environment import UnityEnvironment |
|
|
|
from mlagents.trainers.settings import RunOptions |
|
|
|
from mlagents.trainers.settings import RunOptions, FrameworkType |
|
|
|
|
|
|
|
from mlagents.trainers.training_status import GlobalTrainingStatus |
|
|
|
from mlagents_envs.base_env import BaseEnv |
|
|
|
|
|
|
options.environment_parameters, run_seed, restore=checkpoint_settings.resume |
|
|
|
) |
|
|
|
|
|
|
|
force_torch = "torch" in DetectDefault.non_default_args |
|
|
|
if force_torch or any( |
|
|
|
behavior.framework == FrameworkType.TENSORFLOW |
|
|
|
for behavior in options.behaviors.values() |
|
|
|
): |
|
|
|
os.environ["OMP_NUM_THREADS"] = "1" |
|
|
|
|
|
|
|
trainer_factory = TrainerFactory( |
|
|
|
trainer_config=options.behaviors, |
|
|
|
output_path=write_path, |
|
|
|
|
|
|
param_manager=env_parameter_manager, |
|
|
|
init_path=maybe_init_path, |
|
|
|
multi_gpu=False, |
|
|
|
force_torch="torch" in DetectDefault.non_default_args, |
|
|
|
force_torch=force_torch, |
|
|
|
) |
|
|
|
# Create controller and begin training. |
|
|
|
tc = TrainerController( |
|
|
|