
remove target update

/develop/coma-withq
Andrew Cohen, 4 years ago
Current commit
e3239529
1 changed file with 7 additions and 7 deletions
ml-agents/mlagents/trainers/ppo/optimizer_torch.py (14 changed lines)


 self.stream_names = list(self.reward_signals.keys())
-ModelUtils.soft_update(
-    self.policy.actor_critic.critic, self.policy.actor_critic.target, 1.0
-)
+# ModelUtils.soft_update(
+#     self.policy.actor_critic.critic, self.policy.actor_critic.target, 1.0
+# )

 def ppo_value_loss(
     self,

 old_log_probs = ActionLogProbs.from_dict(batch).flatten()
 log_probs = log_probs.flatten()
 loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
-#q_loss = self.ppo_value_loss(qs, old_values, returns_q, decay_eps, loss_masks)
+# q_loss = self.ppo_value_loss(qs, old_values, returns_q, decay_eps, loss_masks)
 baseline_loss = self.ppo_value_loss(
     baseline_vals, old_marg_values, returns_b, decay_eps, loss_masks
 )

 self.optimizer.step()
-#ModelUtils.soft_update(
+# ModelUtils.soft_update(
-#)
+# )

-#"Losses/Q Loss": q_loss.item(),
+# "Losses/Q Loss": q_loss.item(),
 "Losses/Baseline Value Loss": baseline_loss.item(),
 "Policy/Learning Rate": decay_lr,
 "Policy/Epsilon": decay_eps,
