浏览代码

use hard=true gbsm

/comms-grad
Andrew Cohen 4 年前
当前提交
708ac9bf
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 2
      ml-agents/mlagents/trainers/torch/distributions.py

2
ml-agents/mlagents/trainers/torch/distributions.py


logits = branch(inputs)
norm_logits = self._mask_branch(logits, masks[idx])
distribution = torch.nn.functional.gumbel_softmax(
norm_logits, hard=False, dim=1
norm_logits, hard=True, dim=1
)
branch_distributions.append(distribution)
return branch_distributions

正在加载...
取消
保存