
Fixing merge conflicts

/develop/singular-embeddings
vincentpierre committed 4 years ago
Current commit edbac259
1 changed file with 59 additions and 32 deletions

ml-agents/mlagents/trainers/torch/attention.py (91 changed lines)


 from mlagents.torch_utils import torch
 from typing import Tuple, Optional, List
 from mlagents.trainers.torch.layers import LinearEncoder, Initialization, linear_layer
+from mlagents.trainers.torch.model_serialization import exporting_to_onnx
+from mlagents.trainers.exception import UnityTrainerException
 class MultiHeadAttention(torch.nn.Module):

         self,
         x_self_size: int,
         entity_sizes: List[int],
-        entity_num_max_elements: List[int],
+        entity_num_max_elements: Optional[List[int]],
-        A module that generates embeddings for a variable number of entities given "self"
-        encoding as well as a representation of each entity of each type. The expected
-        input of the forward method will be a list of tensors: each element of the
-        list corresponds to a type of entity, the dimension of each tensor is:
-        [batch_size, max_num_entities, entity_size]
-        :param x_self_size: The size of the self embedding that will be concatenated
-            with the entities
-        :param entity_sizes: The size of each entity type
-        :param entity_num_max_elements: A list of the maximum number of entities, must be
-            the same length as the number of entity tensors that will be passed to the
-            forward method.
-        :param embedding_size: The size of the output embeddings
-        :param concat_self: If true, the x_self will be concatenated with the entities
-            before embedding
+        Constructs an EntityEmbeddings module.
+        :param x_self_size: Size of the "self" entity.
+        :param entity_sizes: List of sizes for other entities. Should be of length
+            equivalent to the number of entities.
+        :param entity_num_max_elements: Maximum number of elements per entity, None for
+            unrestricted. Needs to be assigned in order for the model to be exportable
+            to ONNX and Barracuda.
+        :param embedding_size: Embedding size for entity encoders.
+        :param concat_self: Whether to concatenate x_self to entities. Set True for
+            ego-centric self-attention.
-        self.entity_num_max_elements: List[int] = entity_num_max_elements
+        self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
+        if entity_num_max_elements is not None:
+            self.entity_num_max_elements = entity_num_max_elements
         self.concat_self: bool = concat_self
         # If not concatenating self, input to encoder is just entity size
         if not concat_self:
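A minimal instantiation sketch (not part of this commit) of how the now-optional maximum might be supplied; the keyword names come from the signature and docstring above, while the concrete values and the import path are illustrative assumptions:

    # Hypothetical example: two entity types with observation sizes 6 and 4.
    from mlagents.trainers.torch.attention import EntityEmbeddings

    # With per-type maxima, the module stays exportable to ONNX/Barracuda.
    embeddings = EntityEmbeddings(
        x_self_size=64,
        entity_sizes=[6, 4],
        entity_num_max_elements=[3, 5],
        embedding_size=64,
        concat_self=True,
    )

    # With None, each maximum defaults to -1 and the entity count is read from the
    # input tensor shape at runtime; exporting then raises UnityTrainerException in forward().
    embeddings_dynamic = EntityEmbeddings(
        x_self_size=64,
        entity_sizes=[6, 4],
        entity_num_max_elements=None,
        embedding_size=64,
        concat_self=True,
    )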

             # Concatenate all observations with self
             self_and_ent: List[torch.Tensor] = []
             for num_entities, ent in zip(self.entity_num_max_elements, entities):
+                if num_entities < 0:
+                    if exporting_to_onnx.is_exporting():
+                        raise UnityTrainerException(
+                            "Trying to export an attention mechanism that doesn't have a set max \
+                            number of elements."
+                        )
+                    num_entities = ent.shape[1]
                 expanded_self = x_self.reshape(-1, 1, self.self_size)
                 expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
                 self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
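For reference, a self-contained sketch (names and sizes are illustrative, not from the commit) of the expand-and-concatenate step above: x_self is reshaped to a length-1 entity axis, tiled to the number of entities, and joined with the entity tensor along the feature axis:

    import torch

    batch, num_entities, self_size, ent_size = 2, 3, 4, 5
    x_self = torch.rand(batch, self_size)
    ent = torch.rand(batch, num_entities, ent_size)

    expanded_self = x_self.reshape(-1, 1, self_size)                   # (batch, 1, self_size)
    expanded_self = torch.cat([expanded_self] * num_entities, dim=1)   # (batch, num_entities, self_size)
    self_and_ent = torch.cat([expanded_self, ent], dim=2)              # (batch, num_entities, self_size + ent_size)
    assert self_and_ent.shape == (batch, num_entities, self_size + ent_size)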

 class ResidualSelfAttention(torch.nn.Module):
     """
     Residual self attention inspired by https://arxiv.org/pdf/1909.07528.pdf. Can be used
     with an EntityEmbeddings module, to apply multi head self attention to encode information
     about a "Self" and a list of relevant "Entities".
     """
-        self, embedding_size: int, total_max_elements: int, num_heads: int = 4
+        self,
+        embedding_size: int,
+        entity_num_max_elements: Optional[int] = None,
+        num_heads: int = 4,
-        A simple architecture inspired by https://arxiv.org/pdf/1909.07528.pdf that uses
-        multi head self attention to encode information about a "Self" and a list of
-        relevant "Entities".
-        :param embedding_size: The size of the embeddings that will be generated (should be
-            divisible by num_heads)
-        :param total_max_elements: The maximum total number of entities that can be passed to
-            the module
-        :param num_heads: The number of heads of the attention module
+        Constructs a ResidualSelfAttention module.
+        :param embedding_size: Embedding size for the attention mechanism and the
+            Q, K, V encoders.
+        :param entity_num_max_elements: The maximum number of elements in an entity
+            sequence. Pass None to not restrict the number of elements; however, this
+            will make the module unexportable to ONNX/Barracuda.
+        :param num_heads: Number of heads for Multi Head Self-Attention
-        self.max_num_ent = total_max_elements
+        self.max_num_ent: Optional[int] = None
+        if entity_num_max_elements is not None:
+            self.max_num_ent = entity_num_max_elements
         self.attention = MultiHeadAttention(
             num_heads=num_heads, embedding_size=embedding_size
         )
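A usage sketch (not part of the diff) of the new optional argument; the embedding size and maximum below are arbitrary, and the import path is assumed from the file location:

    from mlagents.trainers.torch.attention import ResidualSelfAttention

    # Fixed maximum: the pooling shapes are static, so the module can be
    # exported to ONNX/Barracuda.
    rsa = ResidualSelfAttention(embedding_size=64, entity_num_max_elements=20, num_heads=4)

    # No maximum: the sequence length is taken from the input at runtime, and
    # attempting to export raises UnityTrainerException in forward().
    rsa_dynamic = ResidualSelfAttention(embedding_size=64)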

         query = self.fc_q(inp)  # (b, n_q, emb)
         key = self.fc_k(inp)  # (b, n_k, emb)
         value = self.fc_v(inp)  # (b, n_k, emb)
-        output, _ = self.attention(
-            query, key, value, self.max_num_ent, self.max_num_ent, mask
-        )
+        # Only use max num if provided
+        if self.max_num_ent is not None:
+            num_ent = self.max_num_ent
+        else:
+            num_ent = inp.shape[1]
+            if exporting_to_onnx.is_exporting():
+                raise UnityTrainerException(
+                    "Trying to export an attention mechanism that doesn't have a set max \
+                    number of elements."
+                )
+        output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
-        numerator = torch.sum(
-            output * (1 - mask).reshape(-1, self.max_num_ent, 1), dim=1
-        )
+        numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
         denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
         output = numerator / denominator
         # Residual between x_self and the output of the module
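The numerator/denominator lines above implement a masked average over the entity axis; a standalone equivalent (with illustrative sizes, and assuming the mask uses 1 for padded entries, as implied by the 1 - mask terms above) would be:

    import torch

    EPSILON = 1e-7  # small constant to avoid division by zero (value assumed)
    batch, num_ent, emb = 2, 4, 8
    output = torch.rand(batch, num_ent, emb)
    mask = torch.tensor([[0.0, 0.0, 1.0, 1.0],
                         [0.0, 1.0, 1.0, 1.0]])  # 1 marks padded (ignored) entities

    numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
    denominator = torch.sum(1 - mask, dim=1, keepdim=True) + EPSILON
    pooled = numerator / denominator  # (batch, emb): mean over the real entities only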