from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM , LinearEncoder
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
ActivationFunction = Callable [ [ torch . Tensor ] , torch . Tensor ]
EncoderFunction = Callable [
else 0
)
(
self . visual_processors ,
self . vector_processors ,
encoder_input_size ,
) = ModelUtils . create_input_processors (
self . processors , self . embedding_sizes = ModelUtils . create_input_processors (
total_enc_size = encoder_input_size + encoded_act_size
total_enc_size = sum ( self . embedding_sizes ) + encoded_act_size
self . linear_encoder = LinearEncoder (
total_enc_size , network_settings . num_layers , self . h_size
)
else :
self . lstm = None # type: ignore
def update_normalization ( self , vec_inputs : List [ torch . Tensor ] ) - > None :
for vec_input , vec_enc in zip ( vec_inputs , self . vector_processors ) :
vec_enc . update_normalization ( vec_input )
def update_normalization ( self , buffer : AgentBuffer ) - > None :
obs = ObsUtil . from_buffer ( buffer , len ( self . processors ) )
for vec_input , enc in zip ( obs , self . processors ) :
if isinstance ( enc , VectorInput ) :
enc . update_normalization ( torch . as_tensor ( vec_input ) )
for n1 , n2 in zip ( self . vector_processors , other_network . vector_processors ) :
n1 . copy_normalization ( n2 )
for n1 , n2 in zip ( self . processors , other_network . processors ) :
if isinstance ( n1 , VectorInput ) and isinstance ( n2 , VectorInput ) :
n1 . copy_normalization ( n2 )
@property
def memory_size ( self ) - > int :
self ,
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
for idx , processor in enumerate ( self . vector_processors ) :
vec_input = vec_inputs [ idx ]
processed_vec = processor ( vec_input )
encodes . append ( processed_vec )
for idx , processor in enumerate ( self . visual_processors ) :
vis_input = vis_inputs [ idx ]
if not exporting_to_onnx . is_exporting ( ) :
vis_input = vis_input . permute ( [ 0 , 3 , 1 , 2 ] )
processed_vis = processor ( vis_input )
encodes . append ( processed_vis )
for idx , processor in enumerate ( self . processors ) :
obs_input = inputs [ idx ]
processed_obs = processor ( obs_input )
encodes . append ( processed_obs )
if len ( encodes ) == 0 :
raise Exception ( " No valid inputs to network. " )
def forward (
self ,
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
vec_inputs , vis_inputs , actions , memories , sequence_length
inputs , actions , memories , sequence_length
)
output = self . value_heads ( encoding )
return output , memories
@abc.abstractmethod
def update_normalization ( self , vector_obs : List [ torch . Tensor ] ) - > None :
def update_normalization ( self , buffer : AgentBuffer ) - > None :
"""
Updates normalization of Actor based on the provided List of vector obs .
: param vector_obs : A List of vector obs as tensors .
def get_action_stats (
self ,
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
masks : Optional [ torch . Tensor ] = None ,
memories : Optional [ torch . Tensor ] = None ,
sequence_length : int = 1 ,
@abc.abstractmethod
def critic_pass (
self ,
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
: param vec_inputs : List of vector inputs as tensors .
: param vis_inputs : List of visual inputs as tensors .
: param inputs : List of inputs as tensors .
: param memories : Tensor of memories , if using memory . Otherwise , None .
: returns : Dict of reward stream to output tensor for values .
"""
def get_action_stats_and_value (
self ,
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
masks : Optional [ torch . Tensor ] = None ,
memories : Optional [ torch . Tensor ] = None ,
sequence_length : int = 1 ,
"""
Returns sampled actions and value estimates .
If memory is enabled , return the memories as well .
: param vec_inputs : A List of vector inputs as tensors .
: param vis_inputs : A List of visual inputs as tensors .
: param inputs : A List of vector inputs as tensors .
: param masks : If using discrete actions , a Tensor of action masks .
: param memories : If using memory , a Tensor of initial memories .
: param sequence_length : If using memory , the sequence length .
def memory_size ( self ) - > int :
return self . network_body . memory_size
def update_normalization ( self , vector_obs : List [ torch . Tensor ] ) - > None :
self . network_body . update_normalization ( vector_obs )
def update_normalization ( self , buffer : AgentBuffer ) - > None :
self . network_body . update_normalization ( buffer )
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
masks : Optional [ torch . Tensor ] = None ,
memories : Optional [ torch . Tensor ] = None ,
sequence_length : int = 1 ,
vec_inputs , vis_inputs , memories = memories , sequence_length = sequence_length
inputs , memories = memories , sequence_length = sequence_length
)
action , log_probs , entropies = self . action_model ( encoding , masks )
return action , log_probs , entropies , memories
At this moment , torch . onnx . export ( ) doesn ' t accept None as tensor to be exported,
so the size of return tuple varies with action spec .
"""
# This code will convert the vec and vis obs into a list of inputs for the network
concatenated_vec_obs = vec_inputs [ 0 ]
inputs = [ ]
start = 0
end = 0
vis_index = 0
for i , enc in enumerate ( self . network_body . processors ) :
if isinstance ( enc , VectorInput ) :
# This is a vec_obs
vec_size = self . network_body . embedding_sizes [ i ]
end = start + vec_size
inputs . append ( concatenated_vec_obs [ : , start : end ] )
start = end
else :
inputs . append ( vis_inputs [ vis_index ] )
vis_index + = 1
# End of code to convert the vec and vis obs into a list of inputs for the network
vec_inputs , vis_inputs , memories = memories , sequence_length = 1
inputs , memories = memories , sequence_length = 1
)
(
def critic_pass (
self ,
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
vec_inputs , vis_inputs , memories = memories , sequence_length = sequence_length
inputs , memories = memories , sequence_length = sequence_length
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
actions : AgentAction ,
masks : Optional [ torch . Tensor ] = None ,
memories : Optional [ torch . Tensor ] = None ,
vec_inputs , vis_inputs , memories = memories , sequence_length = sequence_length
inputs , memories = memories , sequence_length = sequence_length
)
log_probs , entropies = self . action_model . evaluate ( encoding , masks , actions )
value_outputs = self . value_heads ( encoding )
self ,
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
masks : Optional [ torch . Tensor ] = None ,
memories : Optional [ torch . Tensor ] = None ,
sequence_length : int = 1 ,
encoding , memories = self . network_body (
vec_inputs , vis_ inputs, memories = memories , sequence_length = sequence_length
inputs , memories = memories , sequence_length = sequence_length
)
action , log_probs , entropies = self . action_model ( encoding , masks )
value_outputs = self . value_heads ( encoding )
def critic_pass (
self ,
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
vec_inputs , vis_inputs , memories = critic_mem , sequence_length = sequence_length
inputs , memories = critic_mem , sequence_length = sequence_length
)
if actor_mem is not None :
# Make memories with the actor mem unchanged
def get_stats_and_value (
self ,
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
actions : AgentAction ,
masks : Optional [ torch . Tensor ] = None ,
memories : Optional [ torch . Tensor ] = None ,
encoding , actor_mem_outs = self . network_body (
vec_inputs , vis_ inputs, memories = actor_mem , sequence_length = sequence_length
inputs , memories = actor_mem , sequence_length = sequence_length
vec_inputs , vis_ inputs, memories = critic_mem , sequence_length = sequence_length
inputs , memories = critic_mem , sequence_length = sequence_length
)
return log_probs , entropies , value_outputs
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
masks : Optional [ torch . Tensor ] = None ,
memories : Optional [ torch . Tensor ] = None ,
sequence_length : int = 1 ,
vec_inputs ,
vis_inputs ,
masks = masks ,
memories = actor_mem ,
sequence_length = sequence_length ,
inputs , masks = masks , memories = actor_mem , sequence_length = sequence_length
)
if critic_mem is not None :
# Make memories with the actor mem unchanged
def get_action_stats_and_value (
self ,
vec_inputs : List [ torch . Tensor ] ,
vis_inputs : List [ torch . Tensor ] ,
inputs : List [ torch . Tensor ] ,
masks : Optional [ torch . Tensor ] = None ,
memories : Optional [ torch . Tensor ] = None ,
sequence_length : int = 1 ,
actor_mem , critic_mem = self . _get_actor_critic_mem ( memories )
encoding , actor_mem_outs = self . network_body (
vec_inputs , vis_ inputs, memories = actor_mem , sequence_length = sequence_length
inputs , memories = actor_mem , sequence_length = sequence_length
vec_inputs , vis_ inputs, memories = critic_mem , sequence_length = sequence_length
inputs , memories = critic_mem , sequence_length = sequence_length
)
if self . use_lstm :
mem_out = torch . cat ( [ actor_mem_outs , critic_mem_outs ] , dim = - 1 )
def update_normalization ( self , vector_obs : List [ torch . Tensor ] ) - > None :
super ( ) . update_normalization ( vector_obs )
self . critic . network_body . update_normalization ( vector_obs )
def update_normalization ( self , buffer : AgentBuffer ) - > None :
super ( ) . update_normalization ( buffer )
self . critic . network_body . update_normalization ( buffer )
class GlobalSteps ( nn . Module ) :