|
|
|
|
|
|
for name, head in values.items(): |
|
|
|
old_val_tensor = old_values[name] |
|
|
|
returns_tensor = returns[name] |
|
|
|
clipped_value_estimate = old_val_tensor + torch.clamp( |
|
|
|
head - old_val_tensor, -1 * epsilon, epsilon |
|
|
|
) |
|
|
|
# clipped_value_estimate = old_val_tensor + torch.clamp( |
|
|
|
# head - old_val_tensor, -1 * epsilon, epsilon |
|
|
|
# ) |
|
|
|
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2 |
|
|
|
value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks) |
|
|
|
#value_loss = ModelUtils.masked_mean(v_opt_a, loss_masks) |
|
|
|
# v_opt_b = (returns_tensor - clipped_value_estimate) ** 2 |
|
|
|
# value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks) |
|
|
|
value_loss = ModelUtils.masked_mean(v_opt_a, loss_masks) |
|
|
|
value_losses.append(value_loss) |
|
|
|
value_loss = torch.mean(torch.stack(value_losses)) |
|
|
|
return value_loss |
|
|
|