浏览代码

Optimized SAC soft update

/develop/sac-targetq
Ervin Teng 4 年前
当前提交
9088c07a
共有 1 个文件被更改,包括 11 次插入4 次删除
  1. 15
      ml-agents/mlagents/trainers/sac/optimizer_torch.py

15
ml-agents/mlagents/trainers/sac/optimizer_torch.py


return q1_loss, q2_loss
def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None:
for source_param, target_param in zip(source.parameters(), target.parameters()):
target_param.data.copy_(
target_param.data * (1.0 - tau) + source_param.data * tau
)
with torch.no_grad():
for source_param, target_param in zip(
source.parameters(), target.parameters()
):
target_param.data.mul_(1.0 - tau)
torch.add(
target_param.data,
source_param.data,
alpha=tau,
out=target_param.data,
)
def sac_value_loss(
self,

正在加载...
取消
保存