
[refactor] Remove nonfunctional `output_path` option from TrainerSettings (#4087)

/MLA-1734-demo-provider
GitHub · 5 years ago
Current commit
8a49e8e0
68 files changed, with 67 additions and 112 deletions
Changed files (number of changed lines in parentheses):
  1. config/imitation/CrawlerStatic.yaml (1)
  2. config/imitation/FoodCollector.yaml (1)
  3. config/imitation/Hallway.yaml (1)
  4. config/imitation/PushBlock.yaml (1)
  5. config/ppo/3DBall.yaml (1)
  6. config/ppo/3DBallHard.yaml (1)
  7. config/ppo/3DBall_randomize.yaml (1)
  8. config/ppo/Basic.yaml (1)
  9. config/ppo/Bouncer.yaml (1)
  10. config/ppo/CrawlerDynamic.yaml (1)
  11. config/ppo/CrawlerStatic.yaml (1)
  12. config/ppo/FoodCollector.yaml (1)
  13. config/ppo/GridWorld.yaml (1)
  14. config/ppo/Hallway.yaml (1)
  15. config/ppo/PushBlock.yaml (1)
  16. config/ppo/Pyramids.yaml (1)
  17. config/ppo/Reacher.yaml (1)
  18. config/ppo/SoccerTwos.yaml (1)
  19. config/ppo/StrikersVsGoalie.yaml (2)
  20. config/ppo/Tennis.yaml (1)
  21. config/ppo/VisualHallway.yaml (1)
  22. config/ppo/VisualPushBlock.yaml (1)
  23. config/ppo/VisualPyramids.yaml (1)
  24. config/ppo/WalkerDynamic.yaml (1)
  25. config/ppo/WalkerStatic.yaml (1)
  26. config/ppo/WallJump.yaml (2)
  27. config/ppo/WallJump_curriculum.yaml (2)
  28. config/ppo/WormDynamic.yaml (1)
  29. config/ppo/WormStatic.yaml (1)
  30. config/sac/3DBall.yaml (1)
  31. config/sac/3DBallHard.yaml (1)
  32. config/sac/Basic.yaml (1)
  33. config/sac/Bouncer.yaml (1)
  34. config/sac/CrawlerDynamic.yaml (1)
  35. config/sac/CrawlerStatic.yaml (1)
  36. config/sac/FoodCollector.yaml (1)
  37. config/sac/GridWorld.yaml (1)
  38. config/sac/Hallway.yaml (1)
  39. config/sac/PushBlock.yaml (1)
  40. config/sac/Pyramids.yaml (1)
  41. config/sac/Reacher.yaml (1)
  42. config/sac/Tennis.yaml (1)
  43. config/sac/VisualHallway.yaml (1)
  44. config/sac/VisualPushBlock.yaml (1)
  45. config/sac/VisualPyramids.yaml (1)
  46. config/sac/WalkerDynamic.yaml (1)
  47. config/sac/WalkerStatic.yaml (1)
  48. config/sac/WallJump.yaml (2)
  49. config/sac/WormDynamic.yaml (1)
  50. config/sac/WormStatic.yaml (1)
  51. ml-agents/mlagents/trainers/ghost/trainer.py (6)
  52. ml-agents/mlagents/trainers/learn.py (1)
  53. ml-agents/mlagents/trainers/policy/nn_policy.py (4)
  54. ml-agents/mlagents/trainers/policy/tf_policy.py (5)
  55. ml-agents/mlagents/trainers/ppo/trainer.py (7)
  56. ml-agents/mlagents/trainers/sac/trainer.py (15)
  57. ml-agents/mlagents/trainers/settings.py (1)
  58. ml-agents/mlagents/trainers/tests/test_barracuda_converter.py (8)
  59. ml-agents/mlagents/trainers/tests/test_bcmodule.py (2)
  60. ml-agents/mlagents/trainers/tests/test_nn_policy.py (16)
  61. ml-agents/mlagents/trainers/tests/test_policy.py (6)
  62. ml-agents/mlagents/trainers/tests/test_ppo.py (2)
  63. ml-agents/mlagents/trainers/tests/test_reward_signals.py (2)
  64. ml-agents/mlagents/trainers/tests/test_sac.py (10)
  65. ml-agents/mlagents/trainers/tests/test_simple_rl.py (1)
  66. ml-agents/mlagents/trainers/tests/test_trainer_util.py (14)
  67. ml-agents/mlagents/trainers/trainer/trainer.py (12)
  68. ml-agents/mlagents/trainers/trainer_util.py (13)
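For orientation before the per-file hunks: every config YAML below simply loses its `output_path: default` entry (the option had no effect), and the trainer code switches from reading `TrainerSettings.output_path` to an explicit per-trainer artifact path. A minimal sketch of that derivation, using a hypothetical helper name (the real logic lives inline in `trainer_util.initialize_trainer`, shown at the end of this diff):

import os

# Hypothetical helper, not part of ml-agents: illustrates where trainer
# artifacts go after this change. The run-level output path comes from the
# CLI (learn.py passes write_path to TrainerFactory), and each trainer gets
# its own subdirectory named after the behavior/brain.
def trainer_artifact_path(output_path: str, brain_name: str) -> str:
    return os.path.join(output_path, brain_name)

# e.g. trainer_artifact_path("results/testrun", "testbrain")
#      -> "results/testrun/testbrain"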

config/imitation/CrawlerStatic.yaml (1)


use_actions: false
use_vail: false
demo_path: Project/Assets/ML-Agents/Examples/Crawler/Demos/ExpertCrawlerSta.demo
output_path: default
keep_checkpoints: 5
max_steps: 10000000
time_horizon: 1000

config/imitation/FoodCollector.yaml (1)


use_actions: false
use_vail: false
demo_path: Project/Assets/ML-Agents/Examples/FoodCollector/Demos/ExpertFood.demo
output_path: default
keep_checkpoints: 5
max_steps: 2000000
time_horizon: 64

config/imitation/Hallway.yaml (1)


use_actions: false
use_vail: false
demo_path: Project/Assets/ML-Agents/Examples/Hallway/Demos/ExpertHallway.demo
output_path: default
keep_checkpoints: 5
max_steps: 10000000
time_horizon: 64

config/imitation/PushBlock.yaml (1)


use_actions: false
use_vail: false
demo_path: Project/Assets/ML-Agents/Examples/PushBlock/Demos/ExpertPush.demo
output_path: default
keep_checkpoints: 5
max_steps: 15000000
time_horizon: 64

config/ppo/3DBall.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 500000
time_horizon: 1000

config/ppo/3DBallHard.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 5000000
time_horizon: 1000

config/ppo/3DBall_randomize.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 500000
time_horizon: 1000

config/ppo/Basic.yaml (1)


extrinsic:
gamma: 0.9
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 500000
time_horizon: 3

config/ppo/Bouncer.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 4000000
time_horizon: 64

config/ppo/CrawlerDynamic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 10000000
time_horizon: 1000

config/ppo/CrawlerStatic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 10000000
time_horizon: 1000

config/ppo/FoodCollector.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 2000000
time_horizon: 64

config/ppo/GridWorld.yaml (1)


extrinsic:
gamma: 0.9
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 500000
time_horizon: 5

config/ppo/Hallway.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 10000000
time_horizon: 64

config/ppo/PushBlock.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 2000000
time_horizon: 64

config/ppo/Pyramids.yaml (1)


strength: 0.02
encoding_size: 256
learning_rate: 0.0003
output_path: default
keep_checkpoints: 5
max_steps: 10000000
time_horizon: 128

config/ppo/Reacher.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 20000000
time_horizon: 1000

config/ppo/SoccerTwos.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 50000000
time_horizon: 1000

config/ppo/StrikersVsGoalie.yaml (2)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 50000000
time_horizon: 1000

extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 50000000
time_horizon: 1000

config/ppo/Tennis.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 50000000
time_horizon: 1000

config/ppo/VisualHallway.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 10000000
time_horizon: 64

config/ppo/VisualPushBlock.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 3000000
time_horizon: 64

config/ppo/VisualPyramids.yaml (1)


strength: 0.01
encoding_size: 256
learning_rate: 0.0003
output_path: default
keep_checkpoints: 5
max_steps: 10000000
time_horizon: 128

config/ppo/WalkerDynamic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 25000000
time_horizon: 1000

config/ppo/WalkerStatic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 20000000
time_horizon: 1000

config/ppo/WallJump.yaml (2)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 20000000
time_horizon: 128

extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 5000000
time_horizon: 128

config/ppo/WallJump_curriculum.yaml (2)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 20000000
time_horizon: 128

extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 5000000
time_horizon: 128

config/ppo/WormDynamic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 3500000
time_horizon: 1000

config/ppo/WormStatic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 3500000
time_horizon: 1000

config/sac/3DBall.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 500000
time_horizon: 1000

config/sac/3DBallHard.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 500000
time_horizon: 1000

config/sac/Basic.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 500000
time_horizon: 10

config/sac/Bouncer.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 1000000
time_horizon: 64

config/sac/CrawlerDynamic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 5000000
time_horizon: 1000

config/sac/CrawlerStatic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 3000000
time_horizon: 1000

config/sac/FoodCollector.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 2000000
time_horizon: 64

config/sac/GridWorld.yaml (1)


extrinsic:
gamma: 0.9
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 500000
time_horizon: 5

config/sac/Hallway.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 5000000
time_horizon: 64

config/sac/PushBlock.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 2000000
time_horizon: 64

config/sac/Pyramids.yaml (1)


use_actions: true
use_vail: false
demo_path: Project/Assets/ML-Agents/Examples/Pyramids/Demos/ExpertPyramid.demo
output_path: default
keep_checkpoints: 5
max_steps: 10000000
time_horizon: 128

config/sac/Reacher.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 20000000
time_horizon: 1000

config/sac/Tennis.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 20000000
time_horizon: 64

config/sac/VisualHallway.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 10000000
time_horizon: 64

config/sac/VisualPushBlock.yaml (1)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 3000000
time_horizon: 64

config/sac/VisualPyramids.yaml (1)


use_actions: true
use_vail: false
demo_path: Project/Assets/ML-Agents/Examples/Pyramids/Demos/ExpertPyramid.demo
output_path: default
keep_checkpoints: 5
max_steps: 10000000
time_horizon: 128

config/sac/WalkerDynamic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 15000000
time_horizon: 1000

config/sac/WalkerStatic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 20000000
time_horizon: 1000

config/sac/WallJump.yaml (2)


extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 15000000
time_horizon: 128

extrinsic:
gamma: 0.99
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 5000000
time_horizon: 128

config/sac/WormDynamic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 5000000
time_horizon: 1000

config/sac/WormStatic.yaml (1)


extrinsic:
gamma: 0.995
strength: 1.0
output_path: default
keep_checkpoints: 5
max_steps: 3000000
time_horizon: 1000

ml-agents/mlagents/trainers/ghost/trainer.py (6)


reward_buff_cap,
trainer_settings,
training,
run_id,
artifact_path,
):
"""
Creates a GhostTrainer.

:param reward_buff_cap: Max reward history to track in the reward buffer
:param trainer_settings: The parameters for the trainer.
:param training: Whether the trainer is set for training.
:param run_id: The identifier of the current run
:param artifact_path: Path to store artifacts from this trainer.
brain_name, trainer_settings, training, run_id, reward_buff_cap
brain_name, trainer_settings, training, artifact_path, reward_buff_cap
)
self.trainer = trainer

ml-agents/mlagents/trainers/learn.py (1)


)
trainer_factory = TrainerFactory(
options.behaviors,
checkpoint_settings.run_id,
write_path,
not checkpoint_settings.inference,
checkpoint_settings.resume,

ml-agents/mlagents/trainers/policy/nn_policy.py (4)


brain: BrainParameters,
trainer_params: TrainerSettings,
is_training: bool,
model_path: str,
load: bool,
tanh_squash: bool = False,
reparameterize: bool = False,

:param trainer_params: Defined training parameters.
:param is_training: Whether the model should be trained.
:param load: Whether a pre-trained model will be loaded or a new one created.
:param model_path: Path where the model should be saved and loaded.
super().__init__(seed, brain, trainer_params, load)
super().__init__(seed, brain, trainer_params, model_path, load)
self.grads = None
self.update_batch: Optional[tf.Operation] = None
num_layers = self.network_settings.num_layers

ml-agents/mlagents/trainers/policy/tf_policy.py (5)


seed: int,
brain: BrainParameters,
trainer_settings: TrainerSettings,
model_path: str,
load: bool = False,
):
"""

:param trainer_settings: The trainer parameters.
:param model_path: Where to load/save the model.
:param load: If True, load model from model_path. Otherwise, create new model.
"""
self.m_size = 0

self.use_continuous_act = brain.vector_action_space_type == "continuous"
if self.use_continuous_act:
self.num_branches = self.brain.vector_action_space_size[0]
self.model_path = self.trainer_settings.output_path
self.model_path = model_path
self.initialize_path = self.trainer_settings.init_path
self.keep_checkpoints = self.trainer_settings.keep_checkpoints
self.graph = tf.Graph()
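The hunk above is the core of the policy-side refactor: `TFPolicy` now receives `model_path` as a constructor argument instead of reading `self.trainer_settings.output_path`. A self-contained sketch of that constructor shape (simplified stand-in, not the real `TFPolicy`, which also takes `brain` and builds the TF graph):

class PolicySketch:
    # Simplified stand-in for TFPolicy.__init__ after this change.
    def __init__(self, seed: int, trainer_settings, model_path: str, load: bool = False):
        self.seed = seed
        self.trainer_settings = trainer_settings
        # Previously: self.model_path = self.trainer_settings.output_path
        self.model_path = model_path
        self.load = load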

ml-agents/mlagents/trainers/ppo/trainer.py (7)


training: bool,
load: bool,
seed: int,
run_id: str,
artifact_path: str,
):
"""
Responsible for collecting experiences and training PPO model.

:param training: Whether the trainer is set for training.
:param load: Whether the model should be loaded.
:param seed: The seed the model will be initialized with
:param run_id: The identifier of the current run
:param artifact_path: The directory within which to store artifacts from this trainer.
brain_name, trainer_settings, training, run_id, reward_buff_cap
brain_name, trainer_settings, training, artifact_path, reward_buff_cap
)
self.hyperparameters: PPOSettings = cast(
PPOSettings, self.trainer_settings.hyperparameters

brain_parameters,
self.trainer_settings,
self.is_training,
self.artifact_path,
self.load,
condition_sigma_on_obs=False, # Faster training for PPO
create_tf_graph=False, # We will create the TF graph in the Optimizer

ml-agents/mlagents/trainers/sac/trainer.py (15)


training: bool,
load: bool,
seed: int,
run_id: str,
artifact_path: str,
):
"""
Responsible for collecting experiences and training SAC model.

:param training: Whether the trainer is set for training.
:param load: Whether the model should be loaded.
:param seed: The seed the model will be initialized with
:param run_id: The identifier of the current run
:param artifact_path: The directory within which to store artifacts from this trainer.
brain_name, trainer_settings, training, run_id, reward_buff_cap
brain_name, trainer_settings, training, artifact_path, reward_buff_cap
)
self.load = load

"""
Save the training buffer's update buffer to a pickle file.
"""
filename = os.path.join(
self.trainer_settings.output_path, "last_replay_buffer.hdf5"
)
filename = os.path.join(self.artifact_path, "last_replay_buffer.hdf5")
logger.info("Saving Experience Replay Buffer to {}".format(filename))
with open(filename, "wb") as file_object:
self.update_buffer.save_to_file(file_object)

Loads the last saved replay buffer from a file.
"""
filename = os.path.join(
self.trainer_settings.output_path, "last_replay_buffer.hdf5"
)
filename = os.path.join(self.artifact_path, "last_replay_buffer.hdf5")
logger.info("Loading Experience Replay Buffer from {}".format(filename))
with open(filename, "rb+") as file_object:
self.update_buffer.load_from_file(file_object)

brain_parameters,
self.trainer_settings,
self.is_training,
self.artifact_path,
self.load,
tanh_squash=True,
reparameterize=True,
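Same pattern on the trainer side: the SAC replay-buffer file now lives under the trainer's injected `artifact_path` rather than under `trainer_settings.output_path`. A self-contained sketch (hypothetical class, mirroring the two `filename = os.path.join(...)` changes above):

import os

class SacTrainerSketch:
    def __init__(self, brain_name: str, artifact_path: str):
        self.brain_name = brain_name
        self.artifact_path = artifact_path  # e.g. results/<run_id>/<brain_name>

    def replay_buffer_path(self) -> str:
        # Was: os.path.join(self.trainer_settings.output_path, "last_replay_buffer.hdf5")
        return os.path.join(self.artifact_path, "last_replay_buffer.hdf5")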

ml-agents/mlagents/trainers/settings.py (1)


factory=lambda: {RewardSignalType.EXTRINSIC: RewardSignalSettings()}
)
init_path: Optional[str] = None
output_path: str = "default"
keep_checkpoints: int = 5
checkpoint_interval: int = 500000
max_steps: int = 500000

ml-agents/mlagents/trainers/tests/test_barracuda_converter.py (8)


@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"])
def test_policy_conversion(tmpdir, rnn, visual, discrete):
tf.reset_default_graph()
dummy_config = TrainerSettings(output_path=os.path.join(tmpdir, "test"))
dummy_config = TrainerSettings()
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
dummy_config,
use_rnn=rnn,
model_path=os.path.join(tmpdir, "test"),
use_discrete=discrete,
use_visual=visual,
)
policy.save_model(1000)
settings = SerializationSettings(

ml-agents/mlagents/trainers/tests/test_bcmodule.py (2)


NetworkSettings.MemorySettings() if use_rnn else None
)
policy = NNPolicy(
0, mock_brain, trainer_config, False, False, tanhresample, tanhresample
0, mock_brain, trainer_config, False, "test", False, tanhresample, tanhresample
)
with policy.graph.as_default():
bc_module = BCModule(

ml-agents/mlagents/trainers/tests/test_nn_policy.py (16)


use_rnn: bool = False,
use_discrete: bool = True,
use_visual: bool = False,
model_path: str = "",
load: bool = False,
seed: int = 0,
) -> NNPolicy:

trainer_settings.network_settings.memory = (
NetworkSettings.MemorySettings() if use_rnn else None
)
policy = NNPolicy(seed, mock_brain, trainer_settings, False, load)
policy = NNPolicy(seed, mock_brain, trainer_settings, False, model_path, load)
return policy

trainer_params = TrainerSettings(output_path=path1)
policy = create_policy_mock(trainer_params)
trainer_params = TrainerSettings()
policy = create_policy_mock(trainer_params, model_path=path1)
policy.initialize_or_load()
policy._set_step(2000)
policy.save_model(2000)

# Try load from this path
policy2 = create_policy_mock(trainer_params, load=True, seed=1)
policy2 = create_policy_mock(trainer_params, model_path=path1, load=True, seed=1)
policy2.initialize_or_load()
_compare_two_policies(policy, policy2)
assert policy2.get_current_step() == 2000

trainer_params.init_path = path1
policy3 = create_policy_mock(trainer_params, load=False, seed=2)
policy3 = create_policy_mock(trainer_params, model_path=path1, load=False, seed=2)
policy3.initialize_or_load()
_compare_two_policies(policy2, policy3)

# Test write_stats
with self.assertLogs("mlagents.trainers", level="WARNING") as cm:
path1 = tempfile.mkdtemp()
trainer_params = TrainerSettings(output_path=path1)
policy = create_policy_mock(trainer_params)
trainer_params = TrainerSettings()
policy = create_policy_mock(trainer_params, model_path=path1)
policy.initialize_or_load()
policy._check_model_version(
"0.0.0"

brain_params,
TrainerSettings(network_settings=NetworkSettings(normalize=True)),
False,
"testdir",
False,
)

ml-agents/mlagents/trainers/tests/test_policy.py (6)


def test_take_action_returns_empty_with_no_agents():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings())
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
# Doesn't really matter what this is
dummy_groupspec = BehaviorSpec([(1,)], "continuous", 1)
no_agent_step = DecisionSteps.empty(dummy_groupspec)

def test_take_action_returns_nones_on_missing_values():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings())
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
policy.evaluate = MagicMock(return_value={})
policy.save_memories = MagicMock()
step_with_agents = DecisionSteps(

def test_take_action_returns_action_info_when_available():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings())
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
policy_eval_out = {
"action": np.array([1.0], dtype=np.float32),
"memory_out": np.array([[2.5]], dtype=np.float32),

ml-agents/mlagents/trainers/tests/test_ppo.py (2)


else None
)
policy = NNPolicy(
0, mock_brain, trainer_settings, False, False, create_tf_graph=False
0, mock_brain, trainer_settings, False, "test", False, create_tf_graph=False
)
optimizer = PPOOptimizer(policy, trainer_settings)
return optimizer

ml-agents/mlagents/trainers/tests/test_reward_signals.py (2)


else None
)
policy = NNPolicy(
0, mock_brain, trainer_settings, False, False, create_tf_graph=False
0, mock_brain, trainer_settings, False, "test", False, create_tf_graph=False
)
if trainer_settings.trainer_type == TrainerType.SAC:
optimizer = SACOptimizer(policy, trainer_settings)

ml-agents/mlagents/trainers/tests/test_sac.py (10)


else None
)
policy = NNPolicy(
0, mock_brain, trainer_settings, False, False, create_tf_graph=False
0, mock_brain, trainer_settings, False, "test", False, create_tf_graph=False
)
optimizer = SACOptimizer(policy, trainer_settings)
return optimizer

)
trainer_params = dummy_config
trainer_params.hyperparameters.save_replay_buffer = True
trainer = SACTrainer(mock_brain.brain_name, 1, trainer_params, True, False, 0, 0)
trainer = SACTrainer(
mock_brain.brain_name, 1, trainer_params, True, False, 0, "testdir"
)
policy = trainer.create_policy(mock_brain.brain_name, mock_brain)
trainer.add_policy(mock_brain.brain_name, policy)

# Wipe Trainer and try to load
trainer2 = SACTrainer(mock_brain.brain_name, 1, trainer_params, True, True, 0, 0)
trainer2 = SACTrainer(
mock_brain.brain_name, 1, trainer_params, True, True, 0, "testdir"
)
policy = trainer2.create_policy(mock_brain.brain_name, mock_brain)
trainer2.add_policy(mock_brain.brain_name, policy)

ml-agents/mlagents/trainers/tests/test_simple_rl.py (1)


env_manager = SimpleEnvManager(env, EnvironmentParametersChannel())
trainer_factory = TrainerFactory(
trainer_config=trainer_config,
run_id=run_id,
output_path=dir,
train_model=True,
load_model=False,

ml-agents/mlagents/trainers/tests/test_trainer_util.py (14)


brain_params_mock = BrainParametersMock()
BrainParametersMock.return_value.brain_name = "testbrain"
external_brains = {"testbrain": BrainParametersMock()}
run_id = "testrun"
output_path = "results_dir"
train_model = True
load_model = False

expected_config = PPO_CONFIG
def mock_constructor(
self, brain, reward_buff_cap, trainer_settings, training, load, seed, run_id
self,
brain,
reward_buff_cap,
trainer_settings,
training,
load,
seed,
artifact_path,
):
assert brain == brain_params_mock.brain_name
assert trainer_settings == expected_config

assert seed == seed
assert run_id == run_id
assert artifact_path == os.path.join(output_path, brain_name)
run_id=run_id,
output_path=output_path,
train_model=train_model,
load_model=load_model,

trainer_factory = trainer_util.TrainerFactory(
trainer_config=no_default_config,
run_id="testrun",
output_path="output_path",
train_model=True,
load_model=False,

ml-agents/mlagents/trainers/trainer/trainer.py (12)


brain_name: str,
trainer_settings: TrainerSettings,
training: bool,
run_id: str,
artifact_path: str,
:dict trainer_settings: The parameters for the trainer (dictionary).
:bool training: Whether the trainer is set for training.
:str run_id: The identifier of the current run
:int reward_buff_cap:
:param trainer_settings: The parameters for the trainer (dictionary).
:param training: Whether the trainer is set for training.
:param artifact_path: The directory within which to store artifacts from this trainer
:param reward_buff_cap:
self.run_id = run_id
self.trainer_settings = trainer_settings
self._threaded = trainer_settings.threaded
self._stats_reporter = StatsReporter(brain_name)

self.trajectory_queues: List[AgentManagerQueue[Trajectory]] = []
self.step: int = 0
self.artifact_path = artifact_path
self.summary_freq = self.trainer_settings.summary_freq
@property

ml-agents/mlagents/trainers/trainer_util.py (13)


def __init__(
self,
trainer_config: Dict[str, TrainerSettings],
run_id: str,
output_path: str,
train_model: bool,
load_model: bool,

multi_gpu: bool = False,
):
self.trainer_config = trainer_config
self.run_id = run_id
self.output_path = output_path
self.init_path = init_path
self.train_model = train_model

return initialize_trainer(
self.trainer_config[brain_name],
brain_name,
self.run_id,
self.output_path,
self.train_model,
self.load_model,

def initialize_trainer(
trainer_settings: TrainerSettings,
brain_name: str,
run_id: str,
output_path: str,
train_model: bool,
load_model: bool,

:param trainer_settings: Original trainer configuration loaded from YAML
:param brain_name: Name of the brain to be associated with trainer
:param run_id: Run ID to associate with this training run
:param output_path: Path to save the model and summary statistics
:param keep_checkpoints: How many model checkpoints to keep
:param train_model: Whether to train the model (vs. run inference)

:param meta_curriculum: Optional meta_curriculum, used to determine a reward buffer length for PPOTrainer
:return:
"""
trainer_settings.output_path = os.path.join(output_path, brain_name)
trainer_artifact_path = os.path.join(output_path, brain_name)
if init_path is not None:
trainer_settings.init_path = os.path.join(init_path, brain_name)

train_model,
load_model,
seed,
run_id,
trainer_artifact_path,
)
elif trainer_type == TrainerType.SAC:
trainer = SACTrainer(

train_model,
load_model,
seed,
run_id,
trainer_artifact_path,
)
else:
raise TrainerConfigError(

min_lesson_length,
trainer_settings,
train_model,
run_id,
trainer_artifact_path,
)
return trainer
