|
|
|
|
|
|
        indexed by name. If none, don't update the reward signals.
        :return: Output from update process.
        """
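        # Wall-clock timestamps t0..t9 below are crude instrumentation for profiling this update.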
        t0 = time.time()
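        # Collect the reward tensor for each configured reward signal from the sampled mini-batch.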
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

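        # Keep the target network's observation normalization in sync with the policy's network body.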
        self.target_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        t1 = time.time()
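        # Sample actions (and their log-probabilities) from the current policy for this batch.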
        (
            sampled_actions,
            log_probs,
            # ... (the remaining unpacked outputs and the sampling call that produces them are elided in this excerpt)
            seq_len=self.policy.sequence_length,
            all_log_probs=not self.policy.use_continuous_act,
        )
        t2 = time.time()
        if self.policy.use_continuous_act:
            squeezed_actions = actions.squeeze(-1)
            # q1p_out, q2p_out = self.value_network(
            # ... (the value-network evaluations are elided in this excerpt)
            )
            q1_stream = self._condense_q_streams(q1_out, actions)
            q2_stream = self._condense_q_streams(q2_out, actions)
        t3 = time.time()
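        # Evaluate the target network without tracking gradients; these values provide the bootstrap targets.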
        with torch.no_grad():
            target_values, _ = self.target_network(
                next_vec_obs,
                # ... (the remaining arguments are elided in this excerpt)
            )
        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
        use_discrete = not self.policy.use_continuous_act
        dones = ModelUtils.list_to_tensor(batch["done"])
        t4 = time.time()
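        # Twin Q-function losses against the bootstrapped target values.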
        q1_loss, q2_loss = self.sac_q_loss(
            q1_stream, q2_stream, target_values, dones, rewards, masks
        )
        # ... (the value_loss computation referenced below is elided in this excerpt)
        entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)
        total_value_loss = q1_loss + q2_loss + value_loss
        t5 = time.time()
        t6 = time.time()
        t7 = time.time()
        t8 = time.time()
        # Soft-update the target network toward the current critic (Polyak averaging with rate tau).
        self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
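        # Stats returned to the trainer for logging.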
        update_stats = {
            # ... (additional entries elided in this excerpt)
            "Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
            "Policy/Learning Rate": decay_lr,
        }
        t9 = time.time()
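        # Debug timing output: per-stage elapsed times, printed from the last stage (t9 - t8) back to the first (t1 - t0).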
        print(
            t9 - t8,
            t8 - t7,
            t7 - t6,
            t6 - t5,
            t5 - t4,
            t4 - t3,
            t3 - t2,
            t2 - t1,
            t1 - t0,
        )
        return update_stats

    def update_reward_signals(