from mlagents.torch_utils import torch
from typing import Tuple, Optional, List
from mlagents.trainers.torch.layers import LinearEncoder, Initialization, linear_layer
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
from mlagents.trainers.exception import UnityTrainerException


class MultiHeadAttention(torch.nn.Module):
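    # The body of this class is elided in this excerpt. What follows is a
    # minimal sketch, not the library's exact implementation: it assumes
    # standard scaled dot-product multi-head attention and matches only the
    # constructor keywords and the (query, key, value, n_q, n_k, key_mask)
    # call used by ResidualSelfAttention below (key_mask: 1 = padded entity).
    NEG_INF = float("-inf")

    def __init__(self, embedding_size: int, num_heads: int):
        super().__init__()
        self.n_heads = num_heads
        self.head_size: int = embedding_size // num_heads
        self.embedding_size: int = self.head_size * num_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        n_q: int,
        n_k: int,
        key_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        b = -1  # batch size is inferred by reshape
        # Split the embedding across heads and move heads forward
        query = query.reshape(b, n_q, self.n_heads, self.head_size).permute(0, 2, 1, 3)
        key = key.reshape(b, n_k, self.n_heads, self.head_size).permute(0, 2, 3, 1)
        value = value.reshape(b, n_k, self.n_heads, self.head_size).permute(0, 2, 1, 3)
        # Scaled dot-product attention scores: (b, h, n_q, n_k)
        qk = torch.matmul(query, key) / (self.head_size ** 0.5)
        if key_mask is not None:
            # Broadcast the (b, n_k) padding mask over heads and queries
            qk = qk.masked_fill(key_mask.reshape(b, 1, 1, n_k).bool(), self.NEG_INF)
        att = torch.softmax(qk, dim=-1)
        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb / h)
        value_attention = value_attention.permute(0, 2, 1, 3).reshape(
            b, n_q, self.embedding_size
        )  # (b, n_q, emb)
        return value_attention, att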


class EntityEmbeddings(torch.nn.Module):
"""
A module used to embed entities before passing them to a self - attention block .
Used in conjunction with ResidualSelfAttention to encode information about a self
and additional entities . Can also concatenate self to entities for ego - centric self -
attention . Inspired by architecture used in https : / / arxiv . org / pdf / 1909.07528 . pdf .
"""

    def __init__(
        self,
        x_self_size: int,
        entity_sizes: List[int],
        embedding_size: int,
        entity_num_max_elements: Optional[List[int]] = None,
        concat_self: bool = True,
    ):
"""
Constructs an EntityEmbeddings module .
: param x_self_size : Size of " self " entity .
: param entity_sizes : List of sizes for other entities . Should be of length
equivalent to the number of entities .
: param embedding_size : Embedding size for entity encoders .
: param entity_num_max_elements : Maximum elements in an entity , None for unrestricted .
Needs to be assigned in order for model to be exportable to ONNX and Barracuda .
: param concat_self : Whether to concatenate x_self to entites . Set True for ego - centric
self - attention .
"""
        super().__init__()
        self.self_size: int = x_self_size
        self.entity_sizes: List[int] = entity_sizes
        self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
        if entity_num_max_elements is not None:
            self.entity_num_max_elements = entity_num_max_elements
        self.concat_self: bool = concat_self
        # If not concatenating self, the input to each encoder is just the entity size
        if not concat_self:
            self.self_size = 0
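        # The construction of the per-entity encoders is elided in this
        # excerpt. A minimal sketch, assuming one LinearEncoder (imported
        # above) per entity type, applied to the entity concatenated with self:
        self.ent_encoders = torch.nn.ModuleList(
            [
                LinearEncoder(self.self_size + ent_size, 1, embedding_size)
                for ent_size in self.entity_sizes
            ]
        )

    def forward(
        self, x_self: torch.Tensor, entities: List[torch.Tensor]
    ) -> torch.Tensor:
        # Signature reconstructed from the usage below: x_self is the raw
        # "self" observation and entities holds one (batch, n, entity_size)
        # tensor per entity type.
        if self.concat_self: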
            # Concatenate all observations with self
            self_and_ent: List[torch.Tensor] = []
            for num_entities, ent in zip(self.entity_num_max_elements, entities):
                if num_entities < 0:
                    if exporting_to_onnx.is_exporting():
                        raise UnityTrainerException(
                            "Trying to export an attention mechanism that doesn't "
                            "have a set max number of elements."
                        )
                    num_entities = ent.shape[1]
                expanded_self = x_self.reshape(-1, 1, self.self_size)
                expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
                self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
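        else:
            # Without concat_self, entities are encoded as-is
            self_and_ent = entities
        # The encoding step is elided in this excerpt. A sketch matching the
        # ent_encoders above: encode each tensor with its entity encoder and
        # concatenate along the entity axis.
        encoded_entities = torch.cat(
            [
                ent_encoder(x)
                for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
            ],
            dim=1,
        )
        return encoded_entities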


class ResidualSelfAttention(torch.nn.Module):
"""
    Residual self-attention inspired by https://arxiv.org/pdf/1909.07528.pdf. Can be used
    with an EntityEmbeddings module to apply multi-head self-attention and encode
    information about a "Self" and a list of relevant "Entities".
    """

    EPSILON = 1e-7

    def __init__(
        self,
        embedding_size: int,
        entity_num_max_elements: Optional[List[int]] = None,
        num_heads: int = 4,
    ):
"""
Constructs a ResidualSelfAttention module .
: param embedding_size : Embedding sizee for attention mechanism and
Q , K , V encoders .
: param entity_num_max_elements : A List of ints representing the maximum number
of elements in an entity sequence . Should be of length num_entities . Pass None to
not restrict the number of elements ; however , this will make the module
unexportable to ONNX / Barracuda .
: param num_heads : Number of heads for Multi Head Self - Attention
"""
        super().__init__()
        self.max_num_ent: Optional[int] = None
        if entity_num_max_elements is not None:
            _entity_num_max_elements = entity_num_max_elements
            self.max_num_ent = sum(_entity_num_max_elements)

        self.attention = MultiHeadAttention(
            num_heads=num_heads, embedding_size=embedding_size
        )
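        # The Q, K, V and output projections are elided in this excerpt. A
        # minimal sketch using the imported linear_layer helper; the
        # Initialization.Normal scheme and gain below are assumptions:
        def qkv_layer() -> torch.nn.Module:
            return linear_layer(
                embedding_size,
                embedding_size,
                kernel_init=Initialization.Normal,
                kernel_gain=(0.125 / embedding_size) ** 0.5,
            )

        self.fc_q = qkv_layer()
        self.fc_k = qkv_layer()
        self.fc_v = qkv_layer()
        self.fc_out = qkv_layer()

    def forward(
        self,
        x_self: torch.Tensor,
        inp: torch.Tensor,
        key_masks: List[torch.Tensor],
    ) -> torch.Tensor:
        # Signature reconstructed from the usage in this excerpt: inp is the
        # (b, n, emb) output of EntityEmbeddings, key_masks holds one (b, n)
        # padding mask per entity type (1 = padded), and x_self (b, emb) is
        # used for the residual connection noted at the end.
        mask = torch.cat(key_masks, dim=1)
        # Feed to self-attention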
        query = self.fc_q(inp)  # (b, n_q, emb)
        key = self.fc_k(inp)  # (b, n_k, emb)
        value = self.fc_v(inp)  # (b, n_k, emb)

        # Only use max num if provided
        if self.max_num_ent is not None:
            num_ent = self.max_num_ent
        else:
            num_ent = inp.shape[1]
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't "
                    "have a set max number of elements."
                )

        output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
        numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
        output = numerator / denominator
        # Residual between x_self and the output of the module
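        # (elided in this excerpt; a sketch assuming the fc_out projection
        # defined in __init__ above)
        output = self.fc_out(output) + x_self
        return output


if __name__ == "__main__":
    # Hypothetical smoke test (not part of the original file) showing how the
    # two modules are meant to be wired together; all sizes are arbitrary.
    batch, n_ent, self_size, ent_size, emb = 3, 5, 4, 6, 8
    embed = EntityEmbeddings(
        x_self_size=self_size,
        entity_sizes=[ent_size],
        embedding_size=emb,
        entity_num_max_elements=[n_ent],
    )
    attn = ResidualSelfAttention(embedding_size=emb, entity_num_max_elements=[n_ent])
    x_self = torch.ones((batch, self_size))
    entities = [torch.ones((batch, n_ent, ent_size))]
    key_masks = [torch.zeros((batch, n_ent))]  # 0 = entity is present
    encoded = embed(x_self, entities)  # (batch, n_ent, emb)
    # The x_self given to the attention block must already be in the
    # embedding space, e.g. the output of a separate self-encoder.
    pooled = attn(torch.ones((batch, emb)), encoded, key_masks)  # (batch, emb)
    print(pooled.shape)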