
hallway collab exps on cloud

Branch: /comms-grad
Andrew Cohen, 4 years ago
Current commit: c843e3d4
2 files changed: 5 insertions, 2 deletions
  1. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (5 changes)
  2. ml-agents/mlagents/trainers/torch/distributions.py (2 changes)

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (5 changes)


                memories=memories,
                seq_len=self.policy.sequence_length,
            )
            obs[-1] = comms[0]
            # This is a bit of a hack, but it's what's recommended in the
            # Gumbel-Softmax documentation.
            one_hot_diff_comms = obs[-1] - comms[1].detach() + comms[1]
            obs[-1] = one_hot_diff_comms
            log_probs, entropy, values = self.policy.evaluate_actions(
                obs,
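The `obs[-1] - comms[1].detach() + comms[1]` line is the standard straight-through trick: the forward pass sees the discrete one-hot value, while gradients flow through the soft sample. A minimal standalone sketch of the same idea (not the ml-agents code; the logits and weights here are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative logits for a single 3-way categorical "message".
logits = torch.tensor([[1.0, 2.0, 0.5]], requires_grad=True)

# Differentiable (soft) Gumbel-Softmax sample.
soft = F.gumbel_softmax(logits, hard=False, dim=1)

# Hard one-hot version of the same sample.
hard = F.one_hot(soft.argmax(dim=1), num_classes=soft.shape[1]).float()

# Straight-through: the value equals `hard`, but the gradient is that of `soft`.
st = hard - soft.detach() + soft

# Gradients reach the logits even though the forward value is discrete.
(st * torch.tensor([[1.0, 2.0, 3.0]])).sum().backward()
```

Because `soft.detach()` and `soft` have identical values, the forward result is exactly the one-hot tensor, yet `backward()` still produces a nonzero gradient on `logits`.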

ml-agents/mlagents/trainers/torch/distributions.py (2 changes)


            logits = branch(inputs)
            norm_logits = self._mask_branch(logits, masks[idx])
            distribution = torch.nn.functional.gumbel_softmax(
-               norm_logits, hard=True, dim=1
+               norm_logits, hard=False, dim=1
            )
            branch_distributions.append(distribution)
        return branch_distributions
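The commit flips `hard=True` to `hard=False`, so each branch now returns soft probability vectors rather than straight-through one-hot samples. The difference between the two modes, in a small sketch (batch size and branch size chosen arbitrarily):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)  # batch of 4, one 3-way branch

# hard=False: soft, differentiable samples; each row is a probability vector.
soft = F.gumbel_softmax(logits, hard=False, dim=1)

# hard=True: discrete one-hot samples in the forward pass, with the
# straight-through gradient of the soft sample in the backward pass.
hard = F.gumbel_softmax(logits, hard=True, dim=1)
```

Both modes are differentiable; the choice only changes what the forward pass emits, which matters here because downstream code consumes the sampled communication directly.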
