from typing import Dict, List, Optional, Tuple

import torch
from torch import nn

from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.timers import timed
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.settings import TrainerSettings, SACSettings

EPSILON = 1e-6  # Small value to avoid divide by zero

    def sac_q_loss(
        self,
        q1_out: Dict[str, torch.Tensor],
        q2_out: Dict[str, torch.Tensor],
        target_q1: Dict[str, torch.Tensor],
        target_q2: Dict[str, torch.Tensor],
        next_log_prob: torch.Tensor,
        dones: torch.Tensor,
        rewards: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
        discrete: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
|
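        """
        Computes the SAC Q losses, one per reward stream: the mean squared error
        between each Q estimate and a backup built from the minimum of the two
        target-network Q estimates plus an entropy term.
        """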
        q1_losses = []
        q2_losses = []
        # One Q loss per reward stream.
        for i, name in enumerate(q1_out.keys()):
            q1_stream = q1_out[name].squeeze()
            q2_stream = q2_out[name].squeeze()
            with torch.no_grad():
                _ent_coef = torch.exp(self._log_ent_coef)
                # Clipped double-Q: take the minimum of the two target estimates.
                target_q = torch.min(
                    target_q1[name].squeeze(), target_q2[name].squeeze()
                )
                if not discrete:
                    ent_bonus = -torch.sum(_ent_coef * next_log_prob, dim=1)
                else:
                    branched_per_action_ent = ModelUtils.break_into_branches(
                        next_log_prob * next_log_prob.exp(), self.act_size
                    )
                    # We have to do entropy bonus per action branch
                    branched_ent_bonus = torch.stack(
                        [
                            torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
                            for i, _lp in enumerate(branched_per_action_ent)
                        ]
                    )
                    ent_bonus = -torch.mean(branched_ent_bonus, axis=0).squeeze()
                target_q = target_q + ent_bonus
|
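                # Bellman backup: r + gamma * (1 - done) * (min target Q + entropy bonus).
                # Streams whose use_dones_in_backup entry is 0 ignore episode termination.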
                q_backup = rewards[name] + (
                    (1.0 - self.use_dones_in_backup[name] * dones)
                    * self.gammas[i]
                    * target_q
                )
            _q1_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks
            )
            _q2_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks
            )
            q1_losses.append(_q1_loss)
            q2_losses.append(_q2_loss)
        q1_loss = torch.mean(torch.stack(q1_losses))
        q2_loss = torch.mean(torch.stack(q2_losses))
        return q1_loss, q2_loss
|
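    # Soft (Polyak) update of the target network parameters:
    # target <- (1 - tau) * target + tau * source, applied in place.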
    def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None:
        for source_param, target_param in zip(
            source.parameters(), target.parameters()
        ):
            target_param.data.mul_(1.0 - tau)
            torch.add(
                target_param.data,
                source_param.data,
                alpha=tau,
                out=target_param.data,
            )

    def sac_value_loss(
        self,
        log_probs: torch.Tensor,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
        discrete: bool,
    ) -> torch.Tensor:
|
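        """
        Computes the value-function loss: the target is the expected minimum of the
        two policy Q estimates under the current policy, minus the entropy term.
        """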
        min_policy_qs = {}
        with torch.no_grad():
            _ent_coef = torch.exp(self._log_ent_coef)
            for name in values.keys():
                if not discrete:
                    min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
                else:
                    action_probs = log_probs.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * action_probs, self.act_size
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * action_probs, self.act_size
                    )
                    _q1p_mean = torch.mean(
                        torch.stack(
                            [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p]
                        ),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack(
                            [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p]
                        ),
                        dim=0,
                    )

                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)
|
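        # The value target is the expected minimum policy Q minus the entropy term
        # (computed per action branch for discrete actions); it carries no gradient.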
        value_losses = []
        if not discrete:
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _ent_coef * log_probs, dim=1
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
                )
                value_losses.append(value_loss)
        else:
            branched_per_action_ent = ModelUtils.break_into_branches(
                log_probs * log_probs.exp(), self.act_size
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, axis=0
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf found")
        return value_loss
|
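    # Policy loss: minimize E[alpha * log pi(a|s) - Q(s, a)] over the batch,
    # i.e. maximize expected Q plus policy entropy.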
    def sac_policy_loss(
        self,
        log_probs: torch.Tensor,
        # ...
    ) -> torch.Tensor:
        ...

    @timed
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        ...
        self.target_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        (sampled_actions, log_probs, _, _) = self._sample_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,
            # ...
        )
        ...
        q2_stream = self._condense_q_streams(q2_out, actions)
|
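        # Sample next-step actions from the current policy and evaluate them with the
        # target Q networks; no gradient flows through these targets.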
        with torch.no_grad():
            (next_actions, next_log_probs, _, _) = self._sample_actions(
                next_vec_obs,
                next_vis_obs,
                masks=act_masks,
                # ...
            )
            q1_target, q2_target = self.target_network(
                next_vec_obs,
                next_vis_obs,
                next_actions if self.policy.use_continuous_act else None,
                # ...
            )
|
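        # For discrete branches, reduce the per-action target Q outputs to the Q values
        # of the sampled next actions before forming the backup.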
        if not self.policy.use_continuous_act:
            q1_target = self._condense_q_streams(q1_target, next_actions)
            q2_target = self._condense_q_streams(q2_target, next_actions)
|
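        # The Q losses are built directly from the target-network estimates and the
        # next-step log-probs; the separate value-function backup is disabled below.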
        q1_loss, q2_loss = self.sac_q_loss(
            q1_stream,
            q2_stream,
            q1_target,
            q2_target,
            next_log_probs,
            dones,
            rewards,
            masks,
            discrete=not self.policy.use_continuous_act,
        )
        # value_loss = self.sac_value_loss(
        #     log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete
        # )
        ...
|
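    # _sample_actions returns (actions, log_probs, entropies, memories); with
    # all_log_probs=True, discrete actions yield the full per-action log-prob tensor.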
    def _sample_actions(
        self,
        vec_obs: List[torch.Tensor],
        vis_obs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        seq_len: int = 1,
        all_log_probs: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
        """
        ...
        else:
            actions = actions[:, 0, :]

        return (actions, all_logs if all_log_probs else log_probs, entropies, memories)