浏览代码

fixed discrete loss fn

/develop/add-fire/bc
Andrew Cohen 5 年前
当前提交
0173acb9
共有 1 个文件被更改,包括 6 次插入2 次删除
  1. 8
      ml-agents/mlagents/trainers/components/bc/module_torch.py

8
ml-agents/mlagents/trainers/components/bc/module_torch.py


mini_batch_demo = demo_update_buffer.make_mini_batch(start, end)
run_out = self._update_batch(mini_batch_demo, self.n_sequences)
loss = run_out["loss"]
# TODO: anneal LR
# self.current_lr = update_stats["learning_rate"]
batch_losses.append(loss)
self.has_updated = True

bc_loss = torch.mean(
torch.stack(
[
-torch.nn.functional.log_softmax(log_prob_branch, dim=0)
* expert_actions_branch
torch.sum(
-torch.nn.functional.log_softmax(log_prob_branch, dim=1)
* expert_actions_branch,
dim=1,
)
for log_prob_branch, expert_actions_branch in zip(
log_prob_branches, expert_actions
)

正在加载...
取消
保存