|
|
|
|
|
|
self.decay_learning_rate = ModelUtils.DecayedValue(
    learning_rate_schedule, self.current_lr, 1e-10, self._anneal_steps
)
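# DecayedValue wraps the learning-rate schedule: given the schedule type, an
# initial value (self.current_lr), a floor (1e-10), and the number of anneal
# steps, its get_value(step) returns the annealed rate for that step. This
# reading follows the ml-agents ModelUtils helper of the same name.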
|
|
|
# Adam accepts any iterable of parameters, so the generator can be passed
# directly without wrapping it in a list.
params = self.policy.actor_critic.parameters()
self.optimizer = torch.optim.Adam(params, lr=self.current_lr)
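# A minimal sketch of pushing the annealed rate into the optimizer on each
# update. apply_decayed_lr is a hypothetical helper in the spirit of
# ml-agents' ModelUtils.update_learning_rate; only torch.optim's public
# param_groups API is relied on.
import torch

def apply_decayed_lr(optimizer: torch.optim.Optimizer, lr: float) -> None:
    # Overwrite the learning rate of every parameter group in place.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr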
|
|
|
|
|
|
|
# Load the demonstration file into a buffer, discarding the first return
# value (the demo's behavior spec). The arguments are reconstructed from the
# usual ml-agents demo_to_buffer call and are an assumption here.
_, self.demonstration_buffer = demo_to_buffer(
    settings.demo_path, policy.sequence_length, policy.behavior_spec
)
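# Hedged sketch of how the demonstration buffer is then consumed: one pass
# per epoch, drawing fixed-size mini-batches. sample_mini_batch follows the
# AgentBuffer API; the loop shape and parameter names are illustrative.
def sample_demo_batches(buffer, batch_size: int, sequence_length: int, num_epoch: int):
    # Yield mini-batches of demonstration data for each training epoch.
    for _ in range(num_epoch):
        yield buffer.sample_mini_batch(batch_size, sequence_length)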
|
|
|
|
|
|
if self.policy.use_continuous_act:
    bc_loss = torch.nn.functional.mse_loss(selected_actions, expert_actions)
else:
    # TODO: add epsilon to log_probs
    log_prob_branches = ModelUtils.break_into_branches(
        log_probs, self.policy.act_size
    )
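    # Hedged completion of the discrete branch, mirroring the usual
    # ml-agents pattern: per-branch cross-entropy between the policy's
    # log-probabilities and the one-hot expert actions, averaged across
    # branches. The zip pairing assumes expert_actions holds one one-hot
    # tensor per branch; treat this as a sketch, not the exact shipped code.
    bc_loss = torch.mean(
        torch.stack(
            [
                torch.sum(-log_prob_branch * expert_actions_branch, dim=1)
                for log_prob_branch, expert_actions_branch in zip(
                    log_prob_branches, expert_actions
                )
            ]
        )
    )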
|
|
|
|
|
|
if self.policy.use_continuous_act:
    expert_actions = ModelUtils.list_to_tensor(mini_batch_demo["actions"])
else:
    # one hot
    raw_expert_actions = ModelUtils.list_to_tensor(
        mini_batch_demo["actions"], dtype=torch.long
    )
    # Assumed completion of the truncated call: actions_to_onehot converts
    # per-branch action indices into a list of one-hot tensors.
    expert_actions = ModelUtils.actions_to_onehot(
        raw_expert_actions, self.policy.act_size
    )
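# Standalone illustration of the one-hot conversion above, assuming act_size
# is a list of branch sizes (e.g. [3, 2]) and the raw actions have shape
# [batch, num_branches]: each returned branch tensor is [batch, branch_size].
import torch

def onehot_branches(raw_actions: torch.Tensor, act_size: list) -> list:
    # Minimal stand-in for the assumed ModelUtils.actions_to_onehot helper.
    return [
        torch.nn.functional.one_hot(raw_actions[:, i], num_classes=size).float()
        for i, size in enumerate(act_size)
    ]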
|
|
|
|
|
|
|
memories = []
if self.policy.use_recurrent:
    memories = torch.zeros(
        1, self.n_sequences, self.policy.actor_critic.half_mem_size * 2
    )
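# For recurrent policies each BC update starts its sampled sequences from a
# zeroed memory tensor of shape (1, n_sequences, memory_size). The
# half_mem_size * 2 expression presumably reflects an actor/critic split of
# the LSTM state; that reading is inferred from the name, not confirmed.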
|
|
|