
Do burn-in for PPO

/develop/lstm-burnin
Ervin Teng, 4 years ago
Current commit
a9ca7b3b
1 changed file with 9 additions and 0 deletions

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (9 additions, 0 deletions)


  old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
  log_probs = log_probs.flatten()
  loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
+ seq_len = self.policy.sequence_length
+ if seq_len > 1:
+     # Do burn-in
+     _burn_in_percent = 0.2
+     burn_in_mask = torch.ones_like(loss_masks)
+     burn_in_amt = int(seq_len * _burn_in_percent)
+     for i in range(batch.num_experiences // seq_len):
+         burn_in_mask[seq_len * i : seq_len * i + burn_in_amt] = 0.0
+     loss_masks = loss_masks * burn_in_mask
  value_loss = ModelUtils.trust_region_value_loss(
      values, old_values, returns, decay_eps, loss_masks
  )
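
For context, burn-in zeroes out the loss mask for the first fraction of every recurrent sequence, so that timesteps evaluated before the LSTM hidden state has "warmed up" do not contribute to the loss. Below is a minimal, self-contained sketch of that masking step; the function name apply_burn_in and the flat-mask layout are illustrative assumptions, not part of the ML-Agents API.

import torch

def apply_burn_in(loss_masks: torch.Tensor, seq_len: int, burn_in_percent: float = 0.2) -> torch.Tensor:
    # Zero the mask for the first burn_in_percent of each sequence.
    # Assumes loss_masks is a flat tensor whose length is a multiple of seq_len,
    # mirroring how sequences are laid out contiguously in the batch above.
    burn_in_mask = torch.ones_like(loss_masks)
    burn_in_amt = int(seq_len * burn_in_percent)
    for i in range(loss_masks.shape[0] // seq_len):
        burn_in_mask[seq_len * i : seq_len * i + burn_in_amt] = 0.0
    return loss_masks * burn_in_mask

# Example: sequence length 5 with 20% burn-in masks the first timestep of each sequence.
masks = torch.ones(10)
print(apply_burn_in(masks, seq_len=5))  # tensor([0., 1., 1., 1., 1., 0., 1., 1., 1., 1.])

With _burn_in_percent = 0.2 and, for example, a sequence_length of 64, the first 12 timesteps of each sequence would be excluded from the value loss shown in the diff and from any later loss term that reuses loss_masks.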
