from typing import Optional, Dict, List, Tuple
import numpy as np
from mlagents.torch_utils import torch, default_device
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.torch.layers import linear_layer, Initialization
from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.demo_loader import demo_to_buffer
from mlagents.trainers.buffer import AgentBuffer
self._use_vail = settings.use_vail
self._settings = settings
encoder_settings = NetworkSettings(
    normalize=False,
    hidden_units=settings.encoding_size,
    num_layers=2,
    vis_encode_type=EncoderType.SIMPLE,
)
unencoded_size = (
    self._action_flattener.flattened_size + 1 if settings.use_actions else 0
)  # +1 is for dones
self.encoder = NetworkBody(
    specs.observation_shapes, encoder_settings, unencoded_size
)
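# Illustrative sizing note: with use_actions enabled and, say, a discrete
# action space of branches (3, 2), the flattened one-hot action has size
# 3 + 2 = 5, so unencoded_size is 5 + 1 = 6 (the +1 is the "done" flag);
# the NetworkBody concatenates those extra inputs onto the encoded
# observations.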
estimator_input_size = settings.encoding_size
def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor:
    return self._action_flattener.forward(
        torch.as_tensor(mini_batch["actions"], dtype=torch.float)
    )
def get_state_inputs(
    self, mini_batch: AgentBuffer
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    n_vis = len(self.encoder.visual_processors)
    n_vec = len(self.encoder.vector_processors)
    vec_inputs = (
        [ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)]
        if n_vec > 0
        else []
    )
    vis_inputs = [
        ModelUtils.list_to_tensor(mini_batch["visual_obs%d" % i], dtype=torch.float)
        for i in range(n_vis)
    ]
    return vec_inputs, vis_inputs
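# Illustrative usage (assumed shapes): for a behavior with one vector
# observation of size 8 and two cameras, a 64-element mini_batch yields
# vec_inputs == [Tensor(64, 8)] and vis_inputs with two batched image
# tensors; either list is empty when the behavior has no obs of that kind.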
def compute_estimate(
    self, mini_batch: AgentBuffer, use_vail_noise: bool = False
) -> torch.Tensor:
    """
    :param use_vail_noise: Only when using VAIL. If true, will sample the code;
    if false, will return the mean of the code.
    """
    vec_inputs, vis_inputs = self.get_state_inputs(mini_batch)
    if self._settings.use_actions:
        actions = self.get_action_input(mini_batch)
        dones = torch.as_tensor(mini_batch["done"], dtype=torch.float).unsqueeze(1)
        action_inputs = torch.cat([actions, dones], dim=1)
        hidden, _ = self.encoder(vec_inputs, vis_inputs, action_inputs)
    else:
        hidden, _ = self.encoder(vec_inputs, vis_inputs)
    z_mu: Optional[torch.Tensor] = None
    if self._settings.use_vail:
        z_mu = self._z_mu_layer(hidden)
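        # VAIL (https://arxiv.org/abs/1810.00821) squeezes the discriminator
        # input through a stochastic code z ~ N(z_mu, z_sigma); with
        # use_vail_noise=True the code is sampled for training, otherwise the
        # mean z_mu is used, as the docstring above describes.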
def compute_gradient_magnitude(
    self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
) -> torch.Tensor:
    """
    Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability,
    especially for off-policy. Computes gradients w.r.t. randomly interpolated input.
    """
    policy_vec_inputs, policy_vis_inputs = self.get_state_inputs(policy_batch)
    expert_vec_inputs, expert_vis_inputs = self.get_state_inputs(expert_batch)
    interp_vec_inputs = []
    for policy_vec_input, expert_vec_input in zip(
        policy_vec_inputs, expert_vec_inputs
    ):
        obs_epsilon = torch.rand(policy_vec_input.shape)
        interp_vec_input = (
            obs_epsilon * policy_vec_input + (1 - obs_epsilon) * expert_vec_input
        )
        interp_vec_input.requires_grad = True  # For gradient calculation
        interp_vec_inputs.append(interp_vec_input)
    interp_vis_inputs = []
    for policy_vis_input, expert_vis_input in zip(
        policy_vis_inputs, expert_vis_inputs
    ):
        obs_epsilon = torch.rand(policy_vis_input.shape)
        interp_vis_input = (
            obs_epsilon * policy_vis_input + (1 - obs_epsilon) * expert_vis_input
        )
        interp_vis_input.requires_grad = True  # For gradient calculation
        interp_vis_inputs.append(interp_vis_input)
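    # Illustrative: obs_epsilon is drawn per element, so each interpolated
    # observation sits at a random point on the segment between a policy
    # sample and an expert sample; WGAN-GP penalizes the discriminator's
    # gradient norm at exactly these interpolated points.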
    if self._settings.use_actions:
        policy_action = self.get_action_input(policy_batch)
        expert_action = self.get_action_input(expert_batch)
        action_epsilon = torch.rand(policy_action.shape)
        policy_dones = torch.as_tensor(
            policy_batch["done"], dtype=torch.float
        ).unsqueeze(1)
        expert_dones = torch.as_tensor(
            expert_batch["done"], dtype=torch.float
        ).unsqueeze(1)
        dones_epsilon = torch.rand(policy_dones.shape)
        action_inputs = torch.cat(
            [
                action_epsilon * policy_action
                + (1 - action_epsilon) * expert_action,
                dones_epsilon * policy_dones + (1 - dones_epsilon) * expert_dones,
            ],
            dim=1,
        )
        action_inputs.requires_grad = True  # For gradient calculation
        hidden, _ = self.encoder(
            interp_vec_inputs, interp_vis_inputs, action_inputs
        )
        encoder_input = tuple(
            interp_vec_inputs + interp_vis_inputs + [action_inputs]
        )
    else:
        hidden, _ = self.encoder(interp_vec_inputs, interp_vis_inputs)
        encoder_input = tuple(interp_vec_inputs + interp_vis_inputs)
    # Sum to a scalar so a single autograd.grad call covers the whole batch.
    estimate = self._estimator(hidden).squeeze(1).sum()
    gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]
    # Norm's gradient could be NaN at 0. Use our own safe_norm
    safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
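    # Why not torch.norm: its gradient is NaN at an all-zero input. A minimal
    # standalone repro (illustrative, not part of this module):
    #
    #   x = torch.zeros(3, requires_grad=True)
    #   torch.norm(x).backward()                    # x.grad -> all NaN
    #   y = torch.zeros(3, requires_grad=True)
    #   ((y ** 2).sum() + 1e-7).sqrt().backward()   # y.grad -> finite zeros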