from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM , LinearEncoder , Initialization
from mlagents.trainers.torch.layers import LSTM , LinearEncoder
from mlagents.trainers.torch.attention import (
EntityEmbedding ,
ResidualSelfAttention ,
get_zero_entities_mask ,
)
from mlagents.trainers.torch.attention import EntityEmbedding , ResidualSelfAttention
ActivationFunction = Callable [ [ torch . Tensor ] , torch . Tensor ]
normalize = self . normalize ,
)
entity_num_max : int = 0
var_processors = [ p for p in self . processors if isinstance ( p , EntityEmbedding ) ]
for processor in var_processors :
entity_max : int = processor . entity_num_max_elements
# Only adds entity max if it was known at construction
if entity_max > 0 :
entity_num_max + = entity_max
if len ( var_processors ) > 0 :
if sum ( self . embedding_sizes ) :
self . x_self_encoder = LinearEncoder (
sum ( self . embedding_sizes ) ,
1 ,
self . h_size ,
kernel_init = Initialization . Normal ,
kernel_gain = ( 0.125 / self . h_size ) * * 0.5 ,
)
self . rsa = ResidualSelfAttention ( self . h_size , entity_num_max )
self . rsa , self . x_self_encoder = ModelUtils . create_residual_self_attention (
self . processors , self . embedding_sizes , self . h_size
)
if self . rsa is not None :
total_enc_size = sum ( self . embedding_sizes ) + self . h_size
else :
total_enc_size = sum ( self . embedding_sizes )
memories : Optional [ torch . Tensor ] = None ,
sequence_length : int = 1 ,
) - > Tuple [ torch . Tensor , torch . Tensor ] :
encodes = [ ]
var_len_processor_inputs : List [ Tuple [ nn . Module , torch . Tensor ] ] = [ ]
for idx , processor in enumerate ( self . processors ) :
if not isinstance ( processor , EntityEmbedding ) :
# The input can be encoded without having to process other inputs
obs_input = inputs [ idx ]
processed_obs = processor ( obs_input )
encodes . append ( processed_obs )
else :
var_len_processor_inputs . append ( ( processor , inputs [ idx ] ) )
if len ( encodes ) != 0 :
encoded_self = torch . cat ( encodes , dim = 1 )
input_exist = True
else :
input_exist = False
if len ( var_len_processor_inputs ) > 0 :
# Some inputs need to be processed with a variable length encoder
masks = get_zero_entities_mask ( [ p_i [ 1 ] for p_i in var_len_processor_inputs ] )
embeddings : List [ torch . Tensor ] = [ ]
processed_self = self . x_self_encoder ( encoded_self ) if input_exist else None
for processor , var_len_input in var_len_processor_inputs :
embeddings . append ( processor ( processed_self , var_len_input ) )
qkv = torch . cat ( embeddings , dim = 1 )
attention_embedding = self . rsa ( qkv , masks )
if not input_exist :
encoded_self = torch . cat ( [ attention_embedding ] , dim = 1 )
input_exist = True
else :
encoded_self = torch . cat ( [ encoded_self , attention_embedding ] , dim = 1 )
if not input_exist :
raise Exception (
" The trainer was unable to process any of the provided inputs. "
" Make sure the trained agents has at least one sensor attached to them. "
)
encoded_self = ModelUtils . encode_observations (
inputs , self . processors , self . rsa , self . x_self_encoder
)
if actions is not None :
encoded_self = torch . cat ( [ encoded_self , actions ] , dim = 1 )
encoding = self . linear_encoder ( encoded_self )
normalize = self . normalize ,
)
self . action_spec = action_spec
# This RSA and input are for variable length obs, not for multi-agent.
(
self . input_rsa ,
self . input_x_self_encoder ,
) = ModelUtils . create_residual_self_attention (
self . processors , _input_size , self . h_size
)
if self . input_rsa is not None :
_input_size . append ( self . h_size )
# Modules for self-attention
# Modules for multi-agent self-attention
obs_only_ent_size = sum ( _input_size )
q_ent_size = (
sum ( _input_size )
attn_mask = only_first_obs_flat . isnan ( ) . type ( torch . FloatTensor )
return attn_mask
def _remove_nans_from_obs (
self , all_obs : List [ List [ torch . Tensor ] ] , attention_mask : torch . Tensor
) - > None :
"""
Helper function to remove NaNs from observations using an attention mask.
"""
for i_agent , single_agent_obs in enumerate ( all_obs ) :
for obs in single_agent_obs :
obs [
attention_mask . type ( torch . BoolTensor ) [ : , i_agent ] , : :
] = 0.0 # Remoove NaNs fast
def forward (
self ,
obs_only : List [ List [ torch . Tensor ] ] ,
concat_f_inp = [ ]
if obs :
obs_attn_mask = self . _get_masks_from_nans ( obs )
for i_agent , ( inputs , action ) in enumerate ( zip ( obs , actions ) ) :
encodes = [ ]
for idx , processor in enumerate ( self . processors ) :
obs_input = inputs [ idx ]
obs_input [
obs_attn_mask . type ( torch . BoolTensor ) [ : , i_agent ] , : :
] = 0.0 # Remoove NaNs fast
processed_obs = processor ( obs_input )
encodes . append ( processed_obs )
self . _remove_nans_from_obs ( obs , obs_attn_mask )
for inputs , action in zip ( obs , actions ) :
encoded = ModelUtils . encode_observations (
inputs , self . processors , self . input_rsa , self . input_x_self_encoder
)
torch . cat ( encodes , dim = - 1 ) ,
encoded ,
action . to_flat ( self . action_spec . discrete_branches ) ,
]
concat_f_inp . append ( torch . cat ( cat_encodes , dim = 1 ) )
concat_encoded_obs = [ ]
if obs_only :
obs_only_attn_mask = self . _get_masks_from_nans ( obs_only )
for i_agent , inputs in enumerate ( obs_only ) :
encodes = [ ]
for idx , processor in enumerate ( self . processors ) :
obs_input = inputs [ idx ]
obs_input [
obs_only_attn_mask . type ( torch . BoolTensor ) [ : , i_agent ] , : :
] = 0.0 # Remoove NaNs fast
processed_obs = processor ( obs_input )
encodes . append ( processed_obs )
concat_encoded_obs . append ( torch . cat ( encodes , dim = - 1 ) )
self . _remove_nans_from_obs ( obs_only , obs_only_attn_mask )
for inputs in obs_only :
encoded = ModelUtils . encode_observations (
inputs , self . processors , self . input_rsa , self . input_x_self_encoder
)
concat_encoded_obs . append ( encoded )
g_inp = torch . stack ( concat_encoded_obs , dim = 1 )
self_attn_masks . append ( obs_only_attn_mask )
self_attn_inputs . append ( self . obs_encoder ( None , g_inp ) )