# One value stream per configured reward signal (e.g. extrinsic, GAIL, curiosity).
self.stream_names = list(self.reward_signals.keys())

ModelUtils.soft_update(
    self.policy.actor_critic.critic, self.policy.actor_critic.target, 1.0
)
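
# For reference only: a Polyak-style target update such as ModelUtils.soft_update
# blends source parameters into the target network; with tau = 1.0, as above, it
# reduces to a hard copy of the critic weights into the target critic. The
# function below is an illustrative sketch, not the ML-Agents implementation.
import torch


def soft_update_sketch(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    for source_param, target_param in zip(source.parameters(), target.parameters()):
        # target <- tau * source + (1 - tau) * target
        target_param.data.copy_(
            tau * source_param.data + (1.0 - tau) * target_param.data
        )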


# PPO-style clipped value loss: `values`, `old_values` and `returns` are dicts
# keyed by reward-stream name; `epsilon` is the clipping range and `loss_masks`
# marks the timesteps that count toward the loss.
def ppo_value_loss(
    self,
    values,
    old_values,
    returns,
    epsilon,
    loss_masks,
):
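    # The original body is elided in this excerpt. The sketch below follows the
    # usual PPO clipped value loss, computed per reward stream and averaged with
    # a masked mean over the valid timesteps; treat it as illustrative rather
    # than this optimizer's exact code.
    value_losses = []
    for name, head in values.items():
        old_val_tensor = old_values[name]
        returns_tensor = returns[name]
        # Keep the new value estimate within `epsilon` of the old one, mirroring
        # the PPO policy clipping.
        clipped_value_estimate = old_val_tensor + torch.clamp(
            head - old_val_tensor, -epsilon, epsilon
        )
        v_opt_a = (returns_tensor - head) ** 2
        v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
        value_losses.append(
            ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
        )
    return torch.mean(torch.stack(value_losses))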

# Flatten the stored and freshly computed action log-probabilities, and build
# the mask that excludes padded timesteps from the losses.
old_log_probs = ActionLogProbs.from_dict(batch).flatten()
log_probs = log_probs.flatten()
loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
# q_loss = self.ppo_value_loss(qs, old_values, returns_q, decay_eps, loss_masks)
baseline_loss = self.ppo_value_loss(
    baseline_vals, old_marg_values, returns_b, decay_eps, loss_masks
)
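
# Elided from this excerpt: the PPO surrogate policy loss and the weighted total
# loss. As an illustrative sketch only (the names, the 0.5 weighting and the
# entropy bonus are assumptions, not this optimizer's exact coefficients), the
# combination typically looks like
#     loss = policy_loss + 0.5 * (value_loss + baseline_loss) - decay_bet * entropy_bonus
# followed by self.optimizer.zero_grad() and loss.backward() before the step below.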

self.optimizer.step()
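

# For orientation, the call above is the tail of the standard PyTorch update
# pattern. A minimal, self-contained sketch (illustrative helper name, not the
# surrounding class's exact code):
def gradient_step_sketch(optimizer: torch.optim.Optimizer, total_loss: torch.Tensor) -> None:
    optimizer.zero_grad()  # clear gradients accumulated from the previous update
    total_loss.backward()  # compute gradients of the combined loss
    optimizer.step()       # apply one gradient update to all trainable parameters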

# ModelUtils.soft_update(
#     ...
# )

# Per-update statistics reported for logging:
# "Losses/Q Loss": q_loss.item(),
"Losses/Baseline Value Loss": baseline_loss.item(),
"Policy/Learning Rate": decay_lr,
"Policy/Epsilon": decay_eps,