add critic to optimizer, ppo runs

/develop/action-slice
Andrew Cohen, 4 years ago
Current commit: 6bd396ee
4 files changed, 20 insertions(+), 6 deletions(-)
  1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (4 changes)
  2. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (18 changes)
  3. ml-agents/mlagents/trainers/torch/components/bc/module.py (2 changes)
  4. ml-agents/mlagents/trainers/torch/model_serialization.py (2 changes)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (4 changes)

  next_obs = [obs.unsqueeze(0) for obs in next_obs]
- value_estimates, next_memory = self.policy.actor_critic.critic_pass(
+ value_estimates, next_memory = self.critic.critic_pass(
- next_value_estimate, _ = self.policy.actor_critic.critic_pass(
+ next_value_estimate, _ = self.critic.critic_pass(
      next_obs, next_memory, sequence_length=1
  )
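
The value-estimation path in TorchOptimizer now calls a critic owned by the optimizer instead of policy.actor_critic. A minimal sketch of the call pattern, assuming a stand-in Critic module in place of the ml-agents ValueNetwork (the class, observation size, and shapes below are illustrative, not the library API):

import torch
import torch.nn as nn

class Critic(nn.Module):
    """Stand-in critic exposing a critic_pass like the optimizer-owned network."""

    def __init__(self, obs_size: int):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_size, 64), nn.ReLU(), nn.Linear(64, 1))

    def critic_pass(self, obs_list, memories=None, sequence_length=1):
        # Concatenate per-sensor observations, return value estimates and the
        # (unchanged) memories, mirroring the call sites in the diff above.
        obs = torch.cat(obs_list, dim=-1)
        return self.body(obs), memories

critic = Critic(obs_size=8)
next_obs = [torch.randn(8)]                        # one tensor per sensor
next_obs = [obs.unsqueeze(0) for obs in next_obs]  # add a batch dimension, as in the diff
next_value_estimate, _ = critic.critic_pass(next_obs, None, sequence_length=1)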

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (18 changes)

  from mlagents.trainers.policy.torch_policy import TorchPolicy
  from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
  from mlagents.trainers.settings import TrainerSettings, PPOSettings
+ from mlagents.trainers.torch.networks import ValueNetwork
  from mlagents.trainers.torch.agent_action import AgentAction
  from mlagents.trainers.torch.action_log_probs import ActionLogProbs
  from mlagents.trainers.torch.utils import ModelUtils

  # Create the graph here to give more granular control of the TF graph to the Optimizer.
  super().__init__(policy, trainer_settings)
- params = list(self.policy.actor_critic.parameters())
  reward_signal_configs = trainer_settings.reward_signals
  reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
+ self.critic = ValueNetwork(
+     reward_signal_names,
+     policy.behavior_spec.observation_specs,
+     network_settings=trainer_settings.network_settings,
+ )
+ params = list(self.policy.actor.parameters()) + list(self.critic.parameters())
  self.hyperparameters: PPOSettings = cast(
      PPOSettings, trainer_settings.hyperparameters
  )

  if len(memories) > 0:
      memories = torch.stack(memories).unsqueeze(0)
- log_probs, entropy, values = self.policy.evaluate_actions(
+ log_probs, entropy = self.policy.evaluate_actions(
  )
+ values, _ = self.critic.critic_pass(
+     current_obs, memories=memories, sequence_length=self.policy.sequence_length
+ )
  old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
  log_probs = log_probs.flatten()
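
The PPO optimizer now builds its own value network in __init__ and optimizes actor and critic parameters with a single Adam instance, while value estimates come from a separate critic_pass call rather than from evaluate_actions. A toy sketch of that wiring, assuming placeholder actor/critic networks and a simplified, non-clipped surrogate loss instead of the full PPO objective:

import torch
import torch.nn as nn

class ToyPPOOptimizer:
    def __init__(self, actor: nn.Module, critic: nn.Module, lr: float = 3e-4):
        self.actor = actor
        self.critic = critic
        # One Adam over both parameter sets, mirroring
        # params = list(self.policy.actor.parameters()) + list(self.critic.parameters())
        params = list(actor.parameters()) + list(critic.parameters())
        self.optimizer = torch.optim.Adam(params, lr=lr)

    def update(self, obs, actions, returns):
        # Policy evaluation no longer yields values; query the critic separately.
        dist = torch.distributions.Categorical(logits=self.actor(obs))
        log_probs, entropy = dist.log_prob(actions), dist.entropy()
        values = self.critic(obs).squeeze(-1)
        advantages = (returns - values).detach()
        loss = (
            -(log_probs * advantages).mean()          # policy surrogate (no PPO clipping here)
            + 0.5 * (returns - values).pow(2).mean()  # value loss
            - 0.01 * entropy.mean()                   # entropy bonus
        )
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

actor = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
critic = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 1))
opt = ToyPPOOptimizer(actor, critic)
opt.update(torch.randn(16, 4), torch.randint(0, 2, (16,)), torch.randn(16))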

ml-agents/mlagents/trainers/torch/components/bc/module.py (2 changes)

  self.decay_learning_rate = ModelUtils.DecayedValue(
      learning_rate_schedule, self.current_lr, 1e-10, self._anneal_steps
  )
- params = self.policy.actor_critic.parameters()
+ params = self.policy.actor.parameters()
  self.optimizer = torch.optim.Adam(params, lr=self.current_lr)
  _, self.demonstration_buffer = demo_to_buffer(
      settings.demo_path, policy.sequence_length, policy.behavior_spec
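
With the critic moved off the policy, the behavioral-cloning module only needs to optimize the actor's parameters. A minimal sketch of that change, assuming a placeholder actor network and learning rate:

import torch
import torch.nn as nn

actor = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
current_lr = 3e-4

# The BC optimizer previously covered policy.actor_critic.parameters(); the critic
# now belongs to the trainer's optimizer, so only the actor's parameters are needed.
params = actor.parameters()
bc_optimizer = torch.optim.Adam(params, lr=current_lr)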

ml-agents/mlagents/trainers/torch/model_serialization.py (2 changes)

  with exporting_to_onnx():
      torch.onnx.export(
-         self.policy.actor_critic,
+         self.policy.actor,
          self.dummy_input,
          onnx_output_path,
          opset_version=SerializationSettings.onnx_opset,
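
Model serialization now exports only the actor, since the critic is a training-time component that inference does not need. A minimal export sketch with a placeholder actor, dummy input, and output path (the real opset comes from SerializationSettings.onnx_opset; 9 below is just an example value):

import torch
import torch.nn as nn

actor = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
dummy_input = torch.randn(1, 4)

# Export the actor alone; the optimizer-owned critic is never serialized.
torch.onnx.export(
    actor,
    dummy_input,
    "model.onnx",
    opset_version=9,  # example opset; ml-agents pins SerializationSettings.onnx_opset
)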
