# Slice one minibatch of expert demonstrations out of the update buffer.
mini_batch_demo = demo_update_buffer.make_mini_batch(start, end)
# Run a single gradient update on the demonstration minibatch.
run_out = self._update_batch(mini_batch_demo, self.n_sequences)
loss = run_out["loss"]
# TODO: anneal LR (see the commented sketch after this block)
# self.current_lr = update_stats["learning_rate"]
batch_losses.append(loss)

# Mark that at least one BC update has been applied.
self.has_updated = True
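# One possible shape for the annealing the TODO above refers to (sketch
# only, kept commented out like the line above; `self.optimizer`,
# `self.initial_lr`, `self.num_updates`, and `self.max_steps` are assumed
# attributes, not part of this module):
# decay = max(0.0, 1.0 - self.num_updates / self.max_steps)
# self.current_lr = self.initial_lr * decay
# for param_group in self.optimizer.param_groups:
#     param_group["lr"] = self.current_lr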

# Discrete-action BC loss: for each action branch, cross-entropy between the
# policy's log-probabilities and the expert's one-hot actions (sum over the
# class dimension), averaged over branches and batch.
bc_loss = torch.mean(
    torch.stack(
        [
            torch.sum(
                -torch.nn.functional.log_softmax(log_prob_branch, dim=1)
                * expert_actions_branch,
                dim=1,
            )
            for log_prob_branch, expert_actions_branch in zip(
                log_prob_branches, expert_actions
            )
        ]
    )
)
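
# Hedged sanity check (illustrative only; the names and shapes below are
# assumptions, not from this module): with one-hot expert targets, each
# per-branch term above reduces to ordinary cross-entropy on that branch.
import torch

logits = torch.randn(4, 3)                    # batch of 4, branch size 3
targets = torch.randint(0, 3, (4,))           # expert action indices
one_hot = torch.nn.functional.one_hot(targets, num_classes=3).float()

manual = torch.sum(
    -torch.nn.functional.log_softmax(logits, dim=1) * one_hot, dim=1
).mean()
builtin = torch.nn.functional.cross_entropy(logits, targets)
assert torch.allclose(manual, builtin)        # same quantity, two spellings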
|