
add target network back

/develop/coma-noact
Andrew Cohen, 4 years ago
Current commit: a92baab6
2 changed files with 4 additions and 3 deletions
1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (4 changes)
2. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (3 changes)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (4 changes)


  memory = torch.zeros([1, 1, self.policy.m_size])
- value_estimates, marg_val_estimates, mem = self.policy.actor_critic.critic_pass(
+ value_estimates, marg_val_estimates, mem = self.policy.actor_critic.target_critic_pass(
      current_obs,
      actions,
      memory,
  )
- next_value_estimates, next_marg_val_estimates, next_mem = self.policy.actor_critic.critic_pass(
+ next_value_estimates, next_marg_val_estimates, next_mem = self.policy.actor_critic.target_critic_pass(
      next_obs,
      next_actions,
      memory,
  )
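Both call sites now read their value and marginalized-value estimates from a target critic instead of the online critic. A target network is a lagged copy of the critic whose weights change only through the soft update in the second file below; bootstrapping from it decouples the regression targets from the network currently being optimized. Below is a minimal sketch of the idea, assuming a PyTorch actor-critic; every name except critic_pass / target_critic_pass is hypothetical and not taken from ml-agents:

    import copy

    import torch
    from torch import nn


    class ActorCriticWithTarget(nn.Module):
        # Hypothetical wrapper: an online critic plus a lagged target copy.
        def __init__(self, critic: nn.Module):
            super().__init__()
            self.critic = critic
            # The target starts as an exact copy; afterwards it changes only
            # through soft updates (see the second file below).
            self.target = copy.deepcopy(critic)
            for p in self.target.parameters():
                p.requires_grad = False

        def critic_pass(self, obs, actions, memories):
            # Online critic: trained directly by the value loss.
            return self.critic(obs, actions, memories)

        def target_critic_pass(self, obs, actions, memories):
            # Lagged critic: used above for the value estimates, so the
            # bootstrap targets drift slowly between gradient steps.
            with torch.no_grad():
                return self.target(obs, actions, memories)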
ml-agents/mlagents/trainers/ppo/optimizer_torch.py (3 changes)


  self.optimizer.step()
  ModelUtils.soft_update(
-     self.policy.actor_critic.critic, self.policy.actor_critic.target, 1.
+     self.policy.actor_critic.critic, self.policy.actor_critic.target, 0.005
  )
  update_stats = {
      "Losses/Value Loss": value_loss.item(),
      "Losses/Baseline Value Loss": marg_value_loss.item(),
      "Policy/Advantages": torch.mean(ModelUtils.list_to_tensor(batch["advantages"])).item(),
      "Policy/Learning Rate": decay_lr,
      "Policy/Epsilon": decay_eps,
      "Policy/Beta": decay_bet,

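The other half of the change swaps the soft-update coefficient from 1.0, which hard-copies the online critic into the target on every update and so defeats the purpose of a target network, to 0.005, a slow Polyak average. A sketch of what ModelUtils.soft_update(source, target, tau) plausibly computes, written from the standard Polyak-averaging rule rather than from the ml-agents source:

    import torch
    from torch import nn


    @torch.no_grad()
    def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
        # Polyak averaging: target <- tau * source + (1 - tau) * target.
        # tau = 1.0 is a hard copy (what the old line did on every update);
        # tau = 0.005 blends in only 0.5% of the online weights per step.
        for src, tgt in zip(source.parameters(), target.parameters()):
            tgt.data.mul_(1.0 - tau)
            tgt.data.add_(tau * src.data)

With tau = 0.005 the target tracks the online critic over an effective horizon of roughly 1 / tau = 200 optimizer steps, which is what keeps the bootstrapped value targets stable.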