
[bug-fix] Use proper masking for entropy and policy losses (#4572)

* Use proper masking for entropy and policy losses

* Fix dimension
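
The "Fix dimension" follow-up stacks the per-branch policy losses along dim=1 so the batch dimension stays first and the per-sample loss lines up with loss_masks. A toy illustration of the stacking difference (the shapes below are hypothetical, not taken from the commit):

    import torch

    # Hypothetical: 3 discrete action branches, each contributing a (batch, 1) loss term.
    batch_size, num_branches = 4, 3
    branch_terms = [torch.ones(batch_size, 1) for _ in range(num_branches)]

    # The default dim=0 puts the branch axis first, so the batch is no longer
    # dim 0 and a per-sample mask cannot be applied cleanly.
    print(torch.stack(branch_terms).shape)         # torch.Size([3, 4, 1])

    # dim=1 keeps the batch first; squeezing then yields (batch, branches).
    print(torch.stack(branch_terms, dim=1).shape)  # torch.Size([4, 3, 1])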
Branch: MLA-1734-demo-provider
Current commit: dde34423 (GitHub, 4 years ago)

1 file changed, 5 insertions(+), 4 deletions(-)

ml-agents/mlagents/trainers/sac/optimizer_torch.py (9 changed lines)


@@ ... @@ def sac_policy_loss(
                     for i, (_lp, _qt) in enumerate(
                         zip(branched_per_action_ent, branched_q_term)
                     )
-                ]
+                ],
+                dim=1,
             )
             batch_policy_loss = torch.squeeze(branched_policy_loss)
-        policy_loss = torch.mean(loss_masks * batch_policy_loss)
+        policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
         return policy_loss
 
     def sac_entropy_loss(
@@ ... @@ def sac_entropy_loss(
         if not discrete:
             with torch.no_grad():
                 target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1)
-            entropy_loss = -torch.mean(
-                self._log_ent_coef * loss_masks * target_current_diff
+            entropy_loss = -1 * ModelUtils.masked_mean(
+                self._log_ent_coef * target_current_diff, loss_masks
             )
         else:
             with torch.no_grad():
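
For context on the masking change itself: torch.mean(loss_masks * batch_policy_loss) zeroes out padded steps but still divides by the full batch size, so the loss is silently scaled down whenever a batch contains masked steps; the same applied to the old entropy loss, where loss_masks was multiplied inside the mean. ModelUtils.masked_mean divides by the number of unmasked steps instead. A minimal standalone sketch of the assumed behavior (the real helper ships with ml-agents and may handle broadcasting differently):

    import torch

    def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        # Average only over elements where the mask is 1; the clamp guards
        # against division by zero when every step is masked out.
        masks = masks.float()
        return (tensor * masks).sum() / masks.sum().clamp(min=1.0)

    loss = torch.tensor([1.0, 3.0, 0.0, 0.0])  # last two steps are padding
    mask = torch.tensor([1.0, 1.0, 0.0, 0.0])

    print(torch.mean(mask * loss))  # tensor(1.) -- divides by 4, diluted by padding
    print(masked_mean(loss, mask))  # tensor(2.) -- divides by the 2 valid steps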
