|
|
|
|
|
|
return q1_loss, q2_loss |
|
|
|
|
|
|
|
def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None: |
|
|
|
for source_param, target_param in zip(source.parameters(), target.parameters()): |
|
|
|
target_param.data.copy_( |
|
|
|
target_param.data * (1.0 - tau) + source_param.data * tau |
|
|
|
) |
|
|
|
with torch.no_grad(): |
|
|
|
for source_param, target_param in zip( |
|
|
|
source.parameters(), target.parameters() |
|
|
|
): |
|
|
|
target_param.data.mul_(1.0 - tau) |
|
|
|
torch.add( |
|
|
|
target_param.data, |
|
|
|
source_param.data, |
|
|
|
alpha=tau, |
|
|
|
out=target_param.data, |
|
|
|
) |
|
|
|
|
|
|
|
def sac_value_loss( |
|
|
|
self, |
|
|
|