torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage
)
policy_loss = -1 * ModelUtils.masked_mean(
torch.min(p_opt_a, p_opt_b).flatten(), loss_masks
torch.min(p_opt_a, p_opt_b), loss_masks
return policy_loss