
Removing hacky layer norm

Branch: /develop/singular-embeddings
vincentpierre, 4 years ago
Commit: e7024786
1 file changed, 0 insertions, 5 deletions
ml-agents/mlagents/trainers/torch/networks.py (5 deletions)


      encoder(x) for encoder, x in zip(self.var_len_encoders, var_len_inputs)
  ]
  qkv = self.entities_embeddings(encoded_self, encoded_var_len)
- mu_qkv = torch.mean(qkv, dim=2, keepdim=True)
- qkv = (qkv - mu_qkv) / (
-     torch.sqrt(torch.mean((qkv - mu_qkv) ** 2, dim=2, keepdim=True))
-     + 0.0001
- )
  attention_embedding = self.rsa(qkv, masks)
  encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
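
For context, the five deleted lines implemented layer normalization by hand: they standardize qkv to zero mean and unit variance along dim=2 (the embedding axis), adding a small constant to the denominator for numerical stability. Below is a minimal sketch of that equivalence, not part of the commit, assuming a hypothetical [batch, entities, embedding] tensor shape. Note that torch.nn.functional.layer_norm is not bit-for-bit identical: it adds its epsilon to the variance inside the square root, whereas the removed code added it to the standard deviation outside it.

    import torch
    import torch.nn.functional as F

    # Hypothetical shape: [batch, entities, embedding]; dim=2 is the embedding axis.
    qkv = torch.randn(4, 8, 32)

    # The removed "hacky" normalization: standardize along dim=2,
    # adding 0.0001 to the standard deviation for stability.
    mu_qkv = torch.mean(qkv, dim=2, keepdim=True)
    manual = (qkv - mu_qkv) / (
        torch.sqrt(torch.mean((qkv - mu_qkv) ** 2, dim=2, keepdim=True)) + 0.0001
    )

    # Built-in layer norm without learnable affine parameters; its epsilon
    # sits inside the square root, so results match only approximately.
    builtin = F.layer_norm(qkv, normalized_shape=(32,), eps=1e-4)

    print(torch.allclose(manual, builtin, atol=1e-3))  # True for typical inputs

After this change, qkv flows into self.rsa unnormalized; if normalization later proves necessary, a torch.nn.LayerNorm module with elementwise_affine=False would be the idiomatic replacement for the hand-rolled version.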
