浏览代码

clean up args in mha

/MLA-1734-demo-provider
Andrew Cohen 4 年前
当前提交
3ca65063
共有 1 个文件被更改,包括 2 次插入4 次删除
  1. 6
      ml-agents/mlagents/trainers/torch/attention.py

6
ml-agents/mlagents/trainers/torch/attention.py


key: torch.Tensor,
value: torch.Tensor,
key_mask: Optional[torch.Tensor] = None,
number_of_keys: int = -1,
number_of_queries: int = -1,
n_q: int = -1,
n_k: int = -1,
n_q = number_of_queries
n_k = number_of_keys
query = query.reshape(
b, n_q, self.n_heads, self.head_size

正在加载...
取消
保存