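        # Initialize the target network as an exact copy of the critic:
        # soft_update with tau=1.0 performs a hard copy of the parameters.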
        ModelUtils.soft_update(
            self.policy.actor_critic.critic, self.policy.actor_critic.target, 1.0
        )

    def ppo_value_loss(
        self,
        values: Dict[str, torch.Tensor],
        old_values: Dict[str, torch.Tensor],
        returns: Dict[str, torch.Tensor],
        epsilon: float,
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluates value loss for PPO.
        :param values: Value output of the current network.
        :param old_values: Value stored with experiences in the buffer.
        :param epsilon: Clipping value for the value estimate; unused while the
            clipped loss below is commented out.
        :param returns: Computed returns.
        :param loss_masks: Mask for losses; used with LSTM to ignore 0-padded
            experiences.
        """
        value_losses = []
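        # Accumulate one loss per value head (one head per reward signal).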
        for name, head in values.items():
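            # PPO-style value clipping is disabled in this variant; the clipped
            # loss is kept below, commented out, for reference.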
            # old_val_tensor = old_values[name]
            returns_tensor = returns[name]  # + 0.99 * old_val_tensor
            # clipped_value_estimate = old_val_tensor + torch.clamp(
            #     head - old_val_tensor, -1 * epsilon, epsilon
            # )
            v_opt_a = (returns_tensor - head) ** 2
            # v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
            # value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
            value_loss = ModelUtils.masked_mean(v_opt_a, loss_masks)
            value_losses.append(value_loss)
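        # Average the per-head losses into a single scalar.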
        value_loss = torch.mean(torch.stack(value_losses))
        return value_loss

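        # Gather per-reward-signal targets from the buffer: returns for the
        # Q-head (returns_q), the counterfactual baseline (returns_b), and the
        # value head (returns_v). The loop header below is an assumption; the
        # stock ml-agents optimizers iterate over self.reward_signals this way.
        for name in self.reward_signals: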
            returns_q[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns_q"])
            returns_b[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns_b"])
            returns_v[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns_v"])

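        # Rebuild the per-sensor observation lists that were flattened into the
        # buffer.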
        n_obs = len(self.policy.behavior_spec.sensor_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)

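        # Flatten old and current log-probabilities so they can be compared
        # elementwise; the boolean masks drop padded (zeroed-out) timesteps
        # from every loss term.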
        old_log_probs = ActionLogProbs.from_dict(batch).flatten()
        log_probs = log_probs.flatten()
        loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)

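        # Both value heads are trained with the same decayed clipping epsilon:
        # the Q-head against returns_q and the baseline head against returns_b.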
        q_loss = self.ppo_value_loss(
            qs, old_values, returns_q, decay_eps, loss_masks
        )
        baseline_loss = self.ppo_value_loss(
            baseline_vals, old_marg_values, returns_b, decay_eps, loss_masks
        )

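        # Take one gradient step on the combined loss.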
        self.optimizer.step()

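        # Polyak-average the critic into the target network after each gradient
        # step; tau=0.005 keeps the target moving slowly.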
        ModelUtils.soft_update(
            self.policy.actor_critic.critic, self.policy.actor_critic.target, 0.005
        )

        update_stats = {
            # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.