            self.act_size,
        )

        # Previously a ValueNetwork mirroring the policy's critic:
        #     self.target_network = ValueNetwork(...)
        self.target_network = TorchSACOptimizer.PolicyValueNetwork(
            self.stream_names,  # leading arguments assumed to mirror value_network above
            self.policy.behavior_spec.observation_shapes,
            policy_network_settings,
            self.policy.behavior_spec.action_type,
            self.act_size,
        )
        # Previously: self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
        self.soft_update(self.value_network, self.target_network, 1.0)
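        # soft_update(source, target, tau) blends parameters as
        #     theta_target <- tau * theta_source + (1 - tau) * theta_target,
        # so tau = 1.0 performs a hard copy at initialization, while the smaller
        # self.tau used after each gradient update (see update() below) gives the
        # usual Polyak-averaged target.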

        self._log_ent_coef = torch.nn.Parameter(
            torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))),
            requires_grad=True,
        )
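        # The entropy coefficient is optimized in log space so that
        # alpha = exp(_log_ent_coef) stays positive; one coefficient is kept per
        # action branch (hence len(self.act_size) entries) for discrete actions.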

        policy_params = list(self.policy.actor_critic.network_body.parameters()) + list(
            self.policy.actor_critic.distribution.parameters()
        )
        # Previously the policy's critic was trained alongside the Q networks:
        #     value_params = list(self.value_network.parameters()) + list(
        #         self.policy.actor_critic.critic.parameters()
        #     )
        value_params = list(self.value_network.parameters())

        logger.debug("value_vars")
        for param in value_params:
            logger.debug(param.shape)

    def sac_q_loss(
        self,
        q1_out: Dict[str, torch.Tensor],
        q2_out: Dict[str, torch.Tensor],
        # target_values: Dict[str, torch.Tensor],  # replaced by the twin Q targets
        target_q1: Dict[str, torch.Tensor],
        target_q2: Dict[str, torch.Tensor],
        dones: torch.Tensor,
        rewards: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        q1_losses = []
        q2_losses = []
        for i, name in enumerate(q1_out.keys()):
            q1_stream = q1_out[name].squeeze()
            q2_stream = q2_out[name].squeeze()
            with torch.no_grad():
                # Clipped double-Q backup: take the smaller of the two target estimates.
                target_q = torch.min(target_q1[name].squeeze(), target_q2[name].squeeze())
                q_backup = rewards[name] + (
                    (1.0 - self.use_dones_in_backup[name] * dones)
                    * self.gammas[i]
                    # * target_values[name]  # previous single value-network backup
                    * target_q
                )
            _q1_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks
            )
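            # Using the minimum of the two bootstrapped estimates (as in TD3/SAC)
            # counteracts Q-value overestimation; _q2_loss is presumably formed the
            # same way against q2_stream in the lines not shown here.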

        # ... (remaining loss helpers and the start of update() are not shown) ...
        self.value_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        # Previously the single-bodied target was updated directly:
        #     self.target_network.network_body.copy_normalization(...)
        self.target_network.q1_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        self.target_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
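        # The observation-normalization statistics live in each network body; copying
        # them from the policy keeps the Q networks and their targets normalizing
        # observations exactly as the actor does.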

        # Previously: (sampled_actions, log_probs, entropies, sampled_values, _) = self.policy.sample_actions(...)
        (sampled_actions, log_probs, entropies, _) = self._sample_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,
            memories=memories,
            seq_len=self.policy.sequence_length,
            all_log_probs=not self.policy.use_continuous_act,
        )
        # ...
        q1_stream = self._condense_q_streams(q1_out, actions)
        q2_stream = self._condense_q_streams(q2_out, actions)

        with torch.no_grad():
            # Previously the backup came from the state-value target network:
            #     target_values, _ = self.target_network(
            #         next_vec_obs,
            #         next_vis_obs,
            #         memories=next_memories,
            #         sequence_length=self.policy.sequence_length,
            #     )
            (
                next_actions,
                _,
                _,
                _,
            ) = self._sample_actions(
                next_vec_obs,
                next_vis_obs,
                masks=act_masks,
                memories=memories,
                seq_len=self.policy.sequence_length,
                all_log_probs=not self.policy.use_continuous_act,
            )
            q1_target, q2_target = self.value_network(
                next_vec_obs,
                next_vis_obs,
                next_actions,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
            )
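            # The bootstrapped targets are computed with gradients disabled, using
            # actions freshly sampled from the current policy at the next
            # observations rather than the actions stored in the replay buffer.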
        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)

        q1_loss, q2_loss = self.sac_q_loss(
            # previously: q1_stream, q2_stream, target_values, dones, rewards, masks
            q1_stream, q2_stream, q1_target, q2_target, dones, rewards, masks
        )
        # The separate state-value loss is no longer computed:
        #     value_loss = self.sac_value_loss(
        #         log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete
        #     )
        # total_value_loss = q1_loss + q2_loss + value_loss
        total_value_loss = q1_loss + q2_loss
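        # With the twin target Q estimates forming the backup, the value optimizer
        # now trains only the two Q networks; the critic's separate value loss is
        # no longer part of this update (see value_params above).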

        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        # ...
        self.entropy_optimizer.step()

        # Update target network
        # Previously: self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
        self.soft_update(self.value_network, self.target_network, self.tau)
        update_stats = {
            # ...
            # previously: "Losses/Value Loss": value_loss.item(),
            "Losses/Value Loss": total_value_loss.item(),
            "Losses/Q1 Loss": q1_loss.item(),
            "Losses/Q2 Loss": q2_loss.item(),
            "Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
            # ...
        }

        # ... (end of update(); the loop below is the tail of get_modules())
        for reward_provider in self.reward_signals.values():
            modules.update(reward_provider.get_modules())
        return modules

    def _sample_actions(
        self,
        vec_obs: List[torch.Tensor],
        vis_obs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        seq_len: int = 1,
        all_log_probs: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Samples actions from the policy's current distributions.
        :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
        """
        dists, memories = self.policy.actor_critic.get_dists(
            vec_obs, vis_obs, masks, memories, seq_len
        )
        action_list = self.policy.actor_critic.sample_action(dists)
        log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
            action_list, dists
        )
        actions = torch.stack(action_list, dim=-1)
        if self.policy.use_continuous_act:
            actions = actions[:, :, 0]
        else:
            actions = actions[:, 0, :]
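        # Shape handling: torch.stack adds a trailing distribution/branch dimension.
        # Continuous control uses a single distribution, so (batch, act_size, 1) is
        # squeezed back to (batch, act_size); each discrete branch contributes a
        # (batch, 1) sample, leaving (batch, num_branches) after dropping the
        # singleton sample dimension.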

        return (
            actions,
            all_logs if all_log_probs else log_probs,
            entropies,
            memories,
        )