浏览代码

Added the algorithm named ppo_transfer

/develop/model-transfer
yanchaosun 5 年前
当前提交
3ef4196e
共有 6 个文件被更改,包括 47 次插入7 次删除
  1. 4
      ml-agents/mlagents/trainers/ppo_transfer/optimizer.py
  2. 9
      ml-agents/mlagents/trainers/ppo_transfer/trainer.py
  3. 4
      ml-agents/mlagents/trainers/settings.py
  4. 11
      ml-agents/mlagents/trainers/trainer_util.py
  5. 26
      config/ppo_transfer/3DBall.yaml

4
ml-agents/mlagents/trainers/ppo_transfer/optimizer.py


from mlagents.trainers.settings import TrainerSettings, PPOSettings
class PPOOptimizer(TFOptimizer):
class PPOTransferOptimizer(TFOptimizer):
The PPO optimizer has a value estimator and a loss function.
The PPO optimizer has a value esåtimator and a loss function.
:param policy: A TFPolicy object that will be updated by this PPO Optimizer.
:param trainer_params: Trainer parameters dictionary that specifies the properties of the trainer.
"""

9
ml-agents/mlagents/trainers/ppo_transfer/trainer.py


from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.ppo.optimizer import PPOOptimizer
from mlagents.trainers.ppo_transfer.optimizer import PPOTransferOptimizer
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, PPOSettings

class PPOTrainer(RLTrainer):
class PPOTransferTrainer(RLTrainer):
"""The PPOTrainer is an implementation of the PPO algorithm."""
def __init__(

:param seed: The seed the model will be initialized with
:param artifact_path: The directory within which to store artifacts from this trainer.
"""
super(PPOTrainer, self).__init__(
super(PPOTransferTrainer, self).__init__(
brain_name, trainer_settings, training, artifact_path, reward_buff_cap
)
self.hyperparameters: PPOSettings = cast(

self.seed = seed
self.policy: NNPolicy = None # type: ignore
print("The current algorithm is PPO Transfer")
def _process_trajectory(self, trajectory: Trajectory) -> None:
"""

if not isinstance(policy, NNPolicy):
raise RuntimeError("Non-NNPolicy passed to PPOTrainer.add_policy()")
self.policy = policy
self.optimizer = PPOOptimizer(self.policy, self.trainer_settings)
self.optimizer = PPOTransferOptimizer(self.policy, self.trainer_settings)
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
# Needed to resume loads properly

4
ml-agents/mlagents/trainers/settings.py


class TrainerType(Enum):
PPO: str = "ppo"
SAC: str = "sac"
PPO_Transfer: str = "ppo_transfer"
_mapping = {TrainerType.PPO: PPOSettings, TrainerType.SAC: SACSettings}
_mapping = {TrainerType.PPO: PPOSettings, TrainerType.SAC: SACSettings,
TrainerType.PPO_Transfer: PPOSettings}
return _mapping[self]

11
ml-agents/mlagents/trainers/trainer_util.py


from mlagents.trainers.trainer import Trainer
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.ppo_transfer.trainer import PPOTransferTrainer
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController

)
elif trainer_type == TrainerType.SAC:
trainer = SACTrainer(
brain_name,
min_lesson_length,
trainer_settings,
train_model,
load_model,
seed,
trainer_artifact_path,
)
elif trainer_type == TrainerType.PPO_Transfer:
trainer = PPOTransferTrainer(
brain_name,
min_lesson_length,
trainer_settings,

26
config/ppo_transfer/3DBall.yaml


behaviors:
3DBall:
trainer_type: ppo_transfer
hyperparameters:
batch_size: 64
buffer_size: 12000
learning_rate: 0.0003
beta: 0.001
epsilon: 0.2
lambd: 0.99
num_epoch: 3
learning_rate_schedule: linear
network_settings:
normalize: true
hidden_units: 128
num_layers: 2
vis_encode_type: simple
reward_signals:
extrinsic:
gamma: 0.99
strength: 1.0
keep_checkpoints: 5
max_steps: 500000
time_horizon: 1000
summary_freq: 12000
threaded: true
正在加载...
取消
保存