
Do not keep gradients on the q for the v backup (#4504)

/MLA-1734-demo-provider
GitHub · 4 years ago
Current commit: 1f179527
1 changed file with 30 additions and 24 deletions
ml-agents/mlagents/trainers/sac/optimizer_torch.py


 min_policy_qs = {}
 with torch.no_grad():
     _ent_coef = torch.exp(self._log_ent_coef)
-for name in values.keys():
-    if not discrete:
-        min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
-    else:
-        action_probs = log_probs.exp()
-        _branched_q1p = ModelUtils.break_into_branches(
-            q1p_out[name] * action_probs, self.act_size
-        )
-        _branched_q2p = ModelUtils.break_into_branches(
-            q2p_out[name] * action_probs, self.act_size
-        )
-        _q1p_mean = torch.mean(
-            torch.stack(
-                [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p]
-            ),
-            dim=0,
-        )
-        _q2p_mean = torch.mean(
-            torch.stack(
-                [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p]
-            ),
-            dim=0,
-        )
-        min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)
+    for name in values.keys():
+        if not discrete:
+            min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
+        else:
+            action_probs = log_probs.exp()
+            _branched_q1p = ModelUtils.break_into_branches(
+                q1p_out[name] * action_probs, self.act_size
+            )
+            _branched_q2p = ModelUtils.break_into_branches(
+                q2p_out[name] * action_probs, self.act_size
+            )
+            _q1p_mean = torch.mean(
+                torch.stack(
+                    [
+                        torch.sum(_br, dim=1, keepdim=True)
+                        for _br in _branched_q1p
+                    ]
+                ),
+                dim=0,
+            )
+            _q2p_mean = torch.mean(
+                torch.stack(
+                    [
+                        torch.sum(_br, dim=1, keepdim=True)
+                        for _br in _branched_q2p
+                    ]
+                ),
+                dim=0,
+            )
+            min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)
 
 value_losses = []
 if not discrete:
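
The point of the change: the min-Q computation that forms the value-network backup target now runs entirely under torch.no_grad(), so the value loss can no longer backpropagate into the policy Q networks. A minimal standalone sketch of the effect; the linear layers, tensor shapes, and entropy coefficient below are illustrative stand-ins, not ML-Agents code:

import torch
import torch.nn as nn
import torch.nn.functional as F

q1 = nn.Linear(4, 1)  # stand-ins for the two policy Q heads
q2 = nn.Linear(4, 1)
v = nn.Linear(4, 1)   # stand-in for the value head being regressed

obs = torch.randn(8, 4)
log_probs = torch.randn(8, 1)  # illustrative log pi(a|s)
ent_coef = torch.tensor(0.2)   # illustrative entropy coefficient

with torch.no_grad():          # the fix: the target carries no gradient
    min_policy_q = torch.min(q1(obs), q2(obs))

v_backup = min_policy_q - ent_coef * log_probs  # constant target for V
value_loss = F.mse_loss(v(obs), v_backup)
value_loss.backward()

print(q1.weight.grad)        # None: no gradient leaked into the Q heads
print(v.weight.grad.shape)   # torch.Size([1, 4]): only V is updated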

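For the discrete (multi-branch) case above, each branch's expected Q under the current policy is summed and the branch-wise expectations are averaged. A hedged sketch of that arithmetic, using plain torch.split as a stand-in for ModelUtils.break_into_branches (assumed here to chunk the concatenated per-action tensor along dim 1, one chunk per action branch):

import torch

act_size = [3, 2]                      # two discrete branches: 3 and 2 actions
q_out = torch.randn(8, sum(act_size))  # Q(s, a) for every action in each branch
action_probs = torch.rand(8, sum(act_size))  # illustrative pi(a|s) weights

# Stand-in for ModelUtils.break_into_branches: split per action branch.
branches = torch.split(q_out * action_probs, act_size, dim=1)

# Expected Q of each branch: sum over that branch's actions, weighted by pi.
per_branch_eq = [b.sum(dim=1, keepdim=True) for b in branches]

# Average the per-branch expectations into one Q estimate per state.
mean_q = torch.stack(per_branch_eq).mean(dim=0)
print(mean_q.shape)  # torch.Size([8, 1])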