|
|
|
|
|
|
# Old log-probs are rebuilt from the batch; both old and current log-probs
# are flattened here.
old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
log_probs = log_probs.flatten()

# Boolean mask built from the batch's MASKS entry; it is passed to the
# value loss below.
loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)

seq_len = self.policy.sequence_length
if seq_len > 1:
    # Burn-in: clear the mask for the first `burn_in_amt` entries of every
    # `seq_len`-long chunk of the batch, so those entries are dropped from
    # `loss_masks`.
    # NOTE(review): presumably this lets a recurrent state warm up before
    # its outputs are trained on -- confirm against the policy in use.
    _burn_in_percent = 0.2
    burn_in_amt = int(seq_len * _burn_in_percent)
    burn_in_mask = torch.ones_like(loss_masks)

    # Only whole chunks are visited; trailing entries beyond the last full
    # chunk keep their mask value.
    num_sequences = batch.num_experiences // seq_len
    for start in range(0, num_sequences * seq_len, seq_len):
        burn_in_mask[start : start + burn_in_amt] = False

    # Combine with the original mask: an entry stays set only if it is set
    # in both masks.
    loss_masks = loss_masks * burn_in_mask

value_loss = ModelUtils.trust_region_value_loss(
    values, old_values, returns, decay_eps, loss_masks
)
|
|
|