|
|
|
|
|
|
ModelUtils.list_to_tensor(batch["action_probs"]), |
|
|
|
loss_masks, |
|
|
|
) |
|
|
|
# Use the sum of entropy across actions, not the mean |
|
|
|
entropy_sum = torch.sum(entropy, dim=1) |
|
|
|
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks) |
|
|
|
- decay_bet * ModelUtils.masked_mean(entropy_sum, loss_masks) |
|
|
|
) |
|
|
|
|
|
|
|
# Set optimizer learning rate |
|
|
|