
more cleanups

/MLA-1734-demo-provider
Andrew Cohen, 4 years ago
Commit 9ae8a720
1 changed file with 9 additions and 10 deletions
ml-agents/mlagents/trainers/torch/attention.py (+9, −10)


  Multi Head Attention module. We do not use the regular Torch implementation since
  Barracuda does not support some operators it uses.
  Takes as input to the forward method 3 tensors:
- - query: of dimensions (batch_size, number_of_queries, key_size)
- - key: of dimensions (batch_size, number_of_keys, key_size)
- - value: of dimensions (batch_size, number_of_keys, value_size)
+ - query: of dimensions (batch_size, number_of_queries, embedding_size)
+ - key: of dimensions (batch_size, number_of_keys, embedding_size)
+ - value: of dimensions (batch_size, number_of_keys, embedding_size)
- - The output: (batch_size, number_of_queries, output_size)
+ - The output: (batch_size, number_of_queries, embedding_size)
  - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
  """

  query: torch.Tensor,
  key: torch.Tensor,
  value: torch.Tensor,
- n_q: int,
- n_k: int,
+ n_q: int = -1,
+ n_k: int = -1,

  # This is to avoid using .size() when possible as Barracuda does not support
  query = query.reshape(
      b, n_q, self.n_heads, self.head_size
  )
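The comment above is the reason n_q and n_k are threaded through forward as plain Python ints (with -1 now meaning "not provided"): reading a dimension via .size() becomes a shape-query operator once the model is exported, and Barracuda cannot execute it, whereas a reshape built from int constants and a single -1 exports cleanly. A sketch of the contrast, with hypothetical helper names:

```python
import torch

def split_heads_static(query: torch.Tensor, n_q: int, n_heads: int, head_size: int) -> torch.Tensor:
    # Barracuda-friendly: the batch dimension is inferred with -1 and every
    # other dimension is a Python int, so no dynamic-shape ops are exported.
    return query.reshape(-1, n_q, n_heads, head_size)

def split_heads_dynamic(query: torch.Tensor, n_heads: int, head_size: int) -> torch.Tensor:
    # Equivalent in eager PyTorch, but .size() turns into shape-query
    # operators on export, which Barracuda does not support.
    b, n_q, _ = query.size()
    return query.reshape(b, n_q, n_heads, head_size)
```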

relevant "Entities".
"""
EPISLON = 1e-7
EPSILON = 1e-7
  def __init__(
      self,

  key = self.fc_k(inp)  # (b, n_k, emb)
  value = self.fc_v(inp)  # (b, n_k, emb)
  output, _ = self.attention(
-     query, key, value, mask, self.max_num_ent, self.max_num_ent
+     query, key, value, self.max_num_ent, self.max_num_ent, mask
  )
  # Residual
  output = self.fc_out(output) + inp
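These lines are the standard self-attention-with-residual wiring: query, key, and value are all projections of the same inp, and the block's input is added back after fc_out. A self-contained sketch of that pattern, using torch.nn.MultiheadAttention as a stand-in for the custom module above (layer names mirror the diff; sizes are placeholders):

```python
import torch
from torch import nn

class TinyResidualSelfAttention(nn.Module):
    def __init__(self, emb: int, n_heads: int):
        super().__init__()
        self.fc_q = nn.Linear(emb, emb)
        self.fc_k = nn.Linear(emb, emb)
        self.fc_v = nn.Linear(emb, emb)
        self.fc_out = nn.Linear(emb, emb)
        # Stand-in for the custom Barracuda-compatible attention module.
        self.attention = nn.MultiheadAttention(emb, n_heads, batch_first=True)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # Self-attention: query, key, and value all come from the same input.
        q, k, v = self.fc_q(inp), self.fc_k(inp), self.fc_v(inp)
        out, _ = self.attention(q, k, v)
        # Residual: add the block's input back onto its output.
        return self.fc_out(out) + inp

x = torch.randn(8, 20, 64)                        # (batch, entities, emb)
print(TinyResidualSelfAttention(64, 4)(x).shape)  # torch.Size([8, 20, 64])
```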

  )
- denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPISLON
+ denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
  output = numerator / denominator
  # Residual between x_self and the output of the module
  return output
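The EPSILON here keeps the masked average finite when every entity is masked out: the denominator counts the unmasked entities, and without the constant an all-masked batch row would divide by zero. Note the polarity implied by 1 - mask: the mask is 1 for padded entities and 0 for real ones. A minimal sketch of the same pattern (tensor shapes are assumptions; they are not visible in this hunk):

```python
import torch

EPSILON = 1e-7  # same constant as in the commit, spelling now fixed

def masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # x:    (batch, n_entities, emb)
    # mask: (batch, n_entities) float, 1.0 = padding, 0.0 = real entity
    keep = (1 - mask).unsqueeze(-1)                  # (batch, n_entities, 1)
    numerator = torch.sum(x * keep, dim=1)           # (batch, emb)
    denominator = torch.sum(1 - mask, dim=1, keepdim=True) + EPSILON  # (batch, 1)
    return numerator / denominator                   # safe even if all-masked
```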