from typing import Dict, cast

from mlagents.torch_utils import torch
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.settings import TrainerSettings, PPOSettings
from mlagents.trainers.torch.networks import ValueNetwork
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents.trainers.torch.utils import ModelUtils


class TorchPPOOptimizer(TorchOptimizer):
    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        # Create the networks here to give this Optimizer granular control over
        # everything it updates.
        super().__init__(policy, trainer_settings)
        reward_signal_configs = trainer_settings.reward_signals
        reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]

        # The critic is a separate ValueNetwork with one value head per reward signal.
        self.critic = ValueNetwork(
            reward_signal_names,
            policy.behavior_spec.observation_specs,
            network_settings=trainer_settings.network_settings,
        )

        # Optimize the actor and the separate critic as one joint parameter list.
        params = list(self.policy.actor.parameters()) + list(self.critic.parameters())
        self.hyperparameters: PPOSettings = cast(
            PPOSettings, trainer_settings.hyperparameters
        )
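
        # Illustrative sketch, not part of the original excerpt: the joint parameter
        # list is typically handed to a single torch optimizer so that one step
        # updates both networks. The Adam choice and the learning_rate field below
        # are assumptions about the surrounding trainer settings.
        self.optimizer = torch.optim.Adam(
            params, lr=self.hyperparameters.learning_rate
        )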

    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        # ... current_obs, act_masks, actions (an AgentAction read from the batch),
        # and the list of per-sequence recurrent memories are prepared above ...
        if len(memories) > 0:
            memories = torch.stack(memories).unsqueeze(0)

        # The policy evaluates log-probs and entropy for the sampled actions, while
        # the separate critic produces the value estimates.
        log_probs, entropy = self.policy.evaluate_actions(
            current_obs, actions=actions, masks=act_masks,
            memories=memories, seq_len=self.policy.sequence_length,
        )
        values, _ = self.critic.critic_pass(
            current_obs, memories=memories, sequence_length=self.policy.sequence_length
        )
        old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
        log_probs = log_probs.flatten()
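
        # Illustrative sketch, not part of the original excerpt: the flattened new
        # and old log-probs feed the clipped PPO surrogate. `advantages`, `loss_masks`,
        # and `decay_eps` are assumed to come from the same batch and from the
        # trainer's epsilon schedule elsewhere in this method.
        r_theta = torch.exp(log_probs - old_log_probs)
        p_opt_a = r_theta * advantages
        p_opt_b = torch.clamp(r_theta, 1.0 - decay_eps, 1.0 + decay_eps) * advantages
        policy_loss = -1 * ModelUtils.masked_mean(torch.min(p_opt_a, p_opt_b), loss_masks)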