from collections import defaultdict

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.policy.policy import Policy
from mlagents.trainers.policy.nn_policy import NNPolicy
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.ppo.optimizer_tf import TFPPOOptimizer
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer

logger = get_logger(__name__)
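
# Backend note: this trainer supports two interchangeable backends, selected
# via self.framework: NNPolicy + TFPPOOptimizer (TensorFlow) and
# TorchPolicy + TorchPPOOptimizer (PyTorch).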


class PPOTrainer(RLTrainer):
    """The PPOTrainer is an implementation of the PPO algorithm."""

    def __init__(
        self,
        brain_name: str,
        reward_buff_cap: int,
        trainer_parameters: dict,
        training: bool,
        load: bool,
        seed: int,
        run_id: str,
    ):
        """
        Responsible for collecting experiences and training a PPO model.
        """
        super().__init__(
            brain_name, trainer_parameters, training, run_id, reward_buff_cap
        )
        self._check_param_keys()
        self.load = load
        self.seed = seed
        self.framework = "torch"
        self.policy: Policy = None  # type: ignore

    def _check_param_keys(self):
        super()._check_param_keys()

    def create_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
    ) -> Policy:
        if self.framework == "torch":
            return self.create_torch_policy(parsed_behavior_id, brain_parameters)
        else:
            return self.create_tf_policy(parsed_behavior_id, brain_parameters)
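
    # A minimal usage sketch (hypothetical caller; `trainer`, the parsed
    # behavior id, and brain parameters would come from the trainer controller):
    #   policy = trainer.create_policy(parsed_behavior_id, brain_parameters)
    #   trainer.add_policy(parsed_behavior_id, policy)
    # With self.framework == "torch" this yields a TorchPolicy; otherwise a
    # TensorFlow NNPolicy.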

    def create_tf_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
    ) -> NNPolicy:
        """
        Creates a PPO policy with a TensorFlow backend for this trainer.
        :param parsed_behavior_id: Behavior identifiers that the policy should belong to.
        :param brain_parameters: Specifications for policy construction.
        :return: The newly created policy.
        """
        policy = NNPolicy(
            self.seed,
            brain_parameters,
            self.trainer_parameters,
            self.is_training,
            self.load,
            condition_sigma_on_obs=False,  # Faster training for PPO
        )
        return policy
    def create_torch_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
    ) -> TorchPolicy:
        """
        Creates a PPO policy with a PyTorch backend for this trainer.
        :param parsed_behavior_id: Behavior identifiers that the policy should belong to.
        :param brain_parameters: Specifications for policy construction.
        :return: The newly created policy.
        """
        policy = TorchPolicy(
            self.seed,
            brain_parameters,
            self.trainer_parameters,
            self.is_training,
            self.load,
            condition_sigma_on_obs=False,  # Faster training for PPO
        )
        return policy
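
    # Design note: condition_sigma_on_obs=False makes the Gaussian policy's
    # log-sigma a learned parameter that is independent of the observation,
    # rather than a network output conditioned on it; per the in-line comment
    # above, this trains faster for PPO.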

    def add_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
    ) -> None:
        """
        Adds policy to trainer.
        :param parsed_behavior_id: Behavior identifiers that the policy should belong to.
        :param policy: Policy to associate with name_behavior_id.
        """
        if not isinstance(policy, Policy):
            raise RuntimeError("Non-Policy passed to PPOTrainer.add_policy()")
        self.policy = policy
        if self.framework == "torch":
            self.optimizer = TorchPPOOptimizer(  # type: ignore
                self.policy, self.trainer_parameters  # type: ignore
            )  # type: ignore
        else:
            self.optimizer = TFPPOOptimizer(  # type: ignore
                self.policy, self.trainer_parameters  # type: ignore
            )  # type: ignore
        for _reward_signal in self.optimizer.reward_signals.keys():
            self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
        # Needed to resume loads properly
        self.step = policy.get_current_step()
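
    # The optimizer branch above must stay consistent with create_policy():
    # both key off self.framework, so a TorchPolicy is always paired with
    # TorchPPOOptimizer and a TensorFlow NNPolicy with TFPPOOptimizer.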

    def get_policy(self, name_behavior_id: str) -> Policy:
        """
        Gets policy from trainer associated with name_behavior_id
        :param name_behavior_id: full identifier of policy
        :return: Policy associated with name_behavior_id
        """
        return self.policy
|