|
|
|
|
|
|
key: torch.Tensor, |
|
|
|
value: torch.Tensor, |
|
|
|
key_mask: Optional[torch.Tensor] = None, |
|
|
|
number_of_keys: int = -1, |
|
|
|
number_of_queries: int = -1, |
|
|
|
n_q: int = -1, |
|
|
|
n_k: int = -1, |
|
|
|
n_q = number_of_queries |
|
|
|
n_k = number_of_keys |
|
|
|
|
|
|
|
query = query.reshape( |
|
|
|
b, n_q, self.n_heads, self.head_size |
|
|
|