expert_batch = self . _demo_buffer . sample_mini_batch (
mini_batch . num_experiences , 1
)
loss , policy_mean_estimate , expert_mean_estimate , kl_loss = self . _discriminator_network . compute_loss (
loss , stats_dict = self . _discriminator_network . compute_loss (
stats_dict = {
" Losses/GAIL Discriminator Loss " : loss . detach ( ) . cpu ( ) . numpy ( ) ,
" Policy/GAIL Policy Estimate " : policy_mean_estimate . detach ( ) . cpu ( ) . numpy ( ) ,
" Policy/GAIL Expert Estimate " : expert_mean_estimate . detach ( ) . cpu ( ) . numpy ( ) ,
}
if self . _discriminator_network . use_vail :
stats_dict [ " Policy/GAIL Beta " ] = (
self . _discriminator_network . beta . detach ( ) . cpu ( ) . numpy ( )
)
stats_dict [ " Losses/GAIL KL Loss " ] = kl_loss . detach ( ) . cpu ( ) . numpy ( )
return stats_dict
def __init__ ( self , specs : BehaviorSpec , settings : GAILSettings ) - > None :
super ( ) . __init__ ( )
self . _policy_specs = specs
self . use_vail = settings . use_vail
self . _ use_vail = settings . use_vail
self . _settings = settings
state_encoder_settings = NetworkSettings (
estimator_input_size = settings . encoding_size
if settings . use_vail :
estimator_input_size = self . z_size
self . z_sigma = torch . nn . Parameter (
self . _ z_sigma = torch . nn . Parameter (
self . z_mu_layer = linear_layer (
self . _ z_mu_layer = linear_layer (
self . beta = torch . nn . Parameter (
self . _ beta = torch . nn . Parameter (
self . estimator = torch . nn . Sequential (
self . _ estimator = torch . nn . Sequential (
linear_layer ( estimator_input_size , 1 ) , torch . nn . Sigmoid ( )
)
hidden = self . encoder ( encoder_input )
z_mu : Optional [ torch . Tensor ] = None
if self . _settings . use_vail :
z_mu = self . z_mu_layer ( hidden )
hidden = torch . normal ( z_mu , self . z_sigma * use_vail_noise )
estimate = self . estimator ( hidden )
z_mu = self . _ z_mu_layer( hidden )
hidden = torch . normal ( z_mu , self . _ z_sigma * use_vail_noise )
estimate = self . _ estimator( hidden )
return estimate , z_mu
def compute_loss (
Given a policy mini_batch and an expert mini_batch , computes the loss of the discriminator .
"""
total_loss = torch . zeros ( 1 )
stats_dict : Dict [ str , np . ndarray ] = { }
policy_estimate , policy_mu = self . compute_estimate (
policy_batch , use_vail_noise = True
)
loss = - (
torch . log ( expert_estimate * ( 1 - self . EPSILON ) )
+ torch . log ( 1.0 - policy_estimate * ( 1 - self . EPSILON ) )
stats_dict [ " Policy/GAIL Policy Estimate " ] = (
policy_estimate . mean ( ) . detach ( ) . cpu ( ) . numpy ( )
)
stats_dict [ " Policy/GAIL Expert Estimate " ] = (
expert_estimate . mean ( ) . detach ( ) . cpu ( ) . numpy ( )
)
discriminator_loss = - (
torch . log ( expert_estimate + self . EPSILON )
+ torch . log ( 1.0 - policy_estimate + self . EPSILON )
kl_loss : Optional [ torch . Tensor ] = None
stats_dict [ " Losses/GAIL Loss " ] = discriminator_loss . detach ( ) . cpu ( ) . numpy ( )
total_loss + = discriminator_loss
+ ( self . z_sigma * * 2 ) . log ( )
+ ( self . _z_sigma * * 2 ) . log ( )
- ( self . z_sigma * * 2 ) ,
- ( self . _z_sigma * * 2 ) ,
vail_loss = self . beta * ( kl_loss - self . mutual_information )
vail_loss = self . _beta * ( kl_loss - self . mutual_information )
self . beta . data = torch . max (
self . beta + self . alpha * ( kl_loss - self . mutual_information ) ,
self . _beta . data = torch . max (
self . _beta + self . alpha * ( kl_loss - self . mutual_information ) ,
loss + = vail_loss
total_loss + = vail_loss
stats_dict [ " Policy/GAIL Beta " ] = self . _beta . detach ( ) . cpu ( ) . numpy ( )
stats_dict [ " Losses/GAIL KL Loss " ] = kl_loss . detach ( ) . cpu ( ) . numpy ( )
loss + = self . gradient_penalty_weight * self . compute_gradient_magnitude (
policy_batch , expert_batch
total_loss + = (
self . gradient_penalty_weight
* self . compute_gradient_magnitude ( policy_batch , expert_batch )
return loss , torch . mean ( policy_estimate ) , torch . mean ( expert_estimate ) , kl_loss
return total_loss , stats_dict
def compute_gradient_magnitude (
self , policy_batch : AgentBuffer , expert_batch : AgentBuffer
hidden = self . encoder ( encoder_input )
if self . _settings . use_vail :
use_vail_noise = True
z_mu = self . z_mu_layer ( hidden )
hidden = torch . normal ( z_mu , self . z_sigma * use_vail_noise )
hidden = self . estimator ( hidden )
z_mu = self . _ z_mu_layer( hidden )
hidden = torch . normal ( z_mu , self . _ z_sigma * use_vail_noise )
hidden = self . _ estimator( hidden )
estimate = torch . mean ( torch . sum ( hidden , dim = 1 ) )
gradient = torch . autograd . grad ( estimate , encoder_input ) [ 0 ]
# Norm's gradient could be NaN at 0. Use our own safe_norm