from mlagents.torch_utils import torch , nn
from mlagents_envs.base_env import ActionSpec , ObservationSpec
from mlagents_envs.base_env import ActionSpec , ObservationSpec , ObservationType
from mlagents.trainers.settings import NetworkSettings , EncoderType
from mlagents.trainers.settings import NetworkSettings , EncoderType , ConditioningType
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM , LinearEncoder
from mlagents.trainers.torch.conditioning import ConditionalEncoder
from mlagents.trainers.torch.attention import (
EntityEmbedding ,
ResidualSelfAttention ,
self . normalize = normalize
self . _total_enc_size = total_enc_size
self . _total_goal_enc_size = 0
self . _goal_processor_indices : List [ int ] = [ ]
for i in range ( len ( observation_specs ) ) :
if observation_specs [ i ] . observation_type == ObservationType . GOAL :
self . _total_goal_enc_size + = self . embedding_sizes [ i ]
self . _goal_processor_indices . append ( i )
@property
def total_enc_size ( self ) - > int :
"""
@property
def total_goal_enc_size(self) -> int:
    """
    Combined encoding size of the goal observations handled by this
    ObservationEncoder, as stored in `_total_goal_enc_size`.
    """
    return self._total_goal_enc_size
def update_normalization ( self , buffer : AgentBuffer ) - > None :
obs = ObsUtil . from_buffer ( buffer , len ( self . processors ) )
"""
Encode observations using a list of processors and an RSA .
: param inputs : List of Tensors corresponding to a set of obs .
: param processors : a ModuleList of the input processors to be applied to these obs .
: param rsa : Optionally , an RSA to use for variable length obs .
: param x_self_encoder : Optionally , an encoder to use for x_self ( in this case , the non - variable inputs . ) .
"""
encodes = [ ]
var_len_processor_inputs : List [ Tuple [ nn . Module , torch . Tensor ] ] = [ ]
return encoded_self
def get_goal_encoding(self, inputs: List[torch.Tensor]) -> torch.Tensor:
    """
    Run each goal observation through its processor and join the results.
    Only the observations whose indices are listed in
    `_goal_processor_indices` are read from `inputs`.
    :param inputs: List of Tensors corresponding to a set of obs.
    :return: The processed goal observations concatenated along dim 1.
    Raises UnityTrainerException if a goal's processor is an EntityEmbedding,
    or if no goal observation was processed at all.
    """
    goal_encodings = []
    for goal_idx in self._goal_processor_indices:
        goal_processor = self.processors[goal_idx]
        if isinstance(goal_processor, EntityEmbedding):
            # Goals backed by an EntityEmbedding processor are rejected outright
            raise UnityTrainerException(
                " The one of the goals uses variable length observations. This use "
                " case is not supported. "
            )
        # This goal can be encoded on its own, without looking at other inputs
        goal_encodings.append(goal_processor(inputs[goal_idx]))
    if len(goal_encodings) == 0:
        raise UnityTrainerException(
            " Trainer was unable to process any of the goals provided as input. "
        )
    return torch.cat(goal_encodings, dim=1)
class NetworkBody ( nn . Module ) :
def __init__ (
self . processors = self . observation_encoder . processors
total_enc_size = self . observation_encoder . total_enc_size
total_enc_size + = encoded_act_size
self . linear_encoder = LinearEncoder (
total_enc_size , network_settings . num_layers , self . h_size
)
if (
self . observation_encoder . total_goal_enc_size > 0
and network_settings . goal_conditioning_type == ConditioningType . HYPER
) :
self . _body_endoder = ConditionalEncoder (
total_enc_size ,
self . observation_encoder . total_goal_enc_size ,
self . h_size ,
network_settings . num_layers ,
1 ,
)
else :
self . _body_endoder = LinearEncoder (
total_enc_size , network_settings . num_layers , self . h_size
)
if self . use_lstm :
self . lstm = LSTM ( self . h_size , self . m_size )
encoded_self = self . observation_encoder ( inputs )
if actions is not None :
encoded_self = torch . cat ( [ encoded_self , actions ] , dim = 1 )
encoding = self . linear_encoder ( encoded_self )
if isinstance ( self . _body_endoder , ConditionalEncoder ) :
goal = self . observation_encoder . get_goal_encoding ( inputs )
encoding = self . _body_endoder ( encoded_self , goal )
else :
encoding = self . _body_endoder ( encoded_self )
if self . use_lstm :
# Resize to (batch, sequence length, encoding size)