浏览代码

address comments

/develop/add-fire/bc
Andrew Cohen 4 年前
当前提交
5f3a94cf
共有 1 个文件被更改,包括 2 次插入4 次删除
  1. 6
      ml-agents/mlagents/trainers/torch/components/bc/module.py

6
ml-agents/mlagents/trainers/torch/components/bc/module.py


self.decay_learning_rate = ModelUtils.DecayedValue(
learning_rate_schedule, self.current_lr, 1e-10, self._anneal_steps
)
params = list(self.policy.actor_critic.parameters())
params = self.policy.actor_critic.parameters()
self.optimizer = torch.optim.Adam(params, lr=self.current_lr)
_, self.demonstration_buffer = demo_to_buffer(

if self.policy.use_continuous_act:
bc_loss = torch.nn.functional.mse_loss(selected_actions, expert_actions)
else:
# TODO: add epsilon to log_probs
log_prob_branches = ModelUtils.break_into_branches(
log_probs, self.policy.act_size
)

if self.policy.use_continuous_act:
expert_actions = ModelUtils.list_to_tensor(mini_batch_demo["actions"])
else:
# one hot
raw_expert_actions = ModelUtils.list_to_tensor(
mini_batch_demo["actions"], dtype=torch.long
)

)
memories = []
if self.policy.actor_critic.use_lstm:
if self.policy.use_recurrent:
memories = torch.zeros(
1, self.n_sequences, self.policy.actor_critic.half_mem_size * 2
)

正在加载...
取消
保存