min_policy_qs = {}
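# Targets for the value loss are computed without gradient tracking.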
with torch.no_grad():
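    # Entropy coefficient alpha, recovered from its log-space parameter
    # (the log parameterization keeps alpha positive).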
    _ent_coef = torch.exp(self._log_ent_coef)
    for name in values.keys():
        if not discrete:
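            # Continuous actions: element-wise minimum of the two
            # Q estimates (clipped double-Q).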
            min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
        else:
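            # Discrete actions: weight each Q value by the policy's action
            # probabilities to get the expected Q under the current policy.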
            action_probs = log_probs.exp()
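            # Split the weighted Q values per action branch
            # (multi-discrete action spaces).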
            _branched_q1p = ModelUtils.break_into_branches(
                q1p_out[name] * action_probs, self.act_size
            )
            _branched_q2p = ModelUtils.break_into_branches(
                q2p_out[name] * action_probs, self.act_size
            )
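            # Sum within each branch (expected Q per branch),
            # then average across branches.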
            _q1p_mean = torch.mean(
                torch.stack(
                    [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p]
                ),
                dim=0,
            )
            _q2p_mean = torch.mean(
                torch.stack(
                    [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p]
                ),
                dim=0,
            )
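            # Clipped double-Q: keep the smaller of the two expected values.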
            min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)
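
# Build one value-function loss per value stream.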
value_losses = []
if not discrete: