from typing import Dict

import torch
from torch import nn

from mlagents_envs.base_env import BehaviorSpec, ObservationType
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.torch.action_flattener import ActionFlattener
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.torch.utils import ModelUtils

logger = get_logger(__name__)


class DiverseNetwork(torch.nn.Module):
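    """
    Predicts the goal-signal observation from the remaining observations (and,
    optionally, the taken action). The log-likelihood it assigns to the true
    goal signal is used as a MEDE diversity pseudo-reward inside the SAC losses.
    """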
    # Small constant that keeps log() away from zero.
    EPSILON = 1e-10
    # Scale of the diversity bonus added to the SAC value and policy losses.
    STRENGTH = 0.1
    def __init__(self, specs: BehaviorSpec, settings) -> None:
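        """
        Builds an encoder over all non-goal observations (optionally concatenated
        with the flattened action) and a linear head sized to the goal signal.
        """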
        super().__init__()
        self._use_actions = True
        state_encoder_settings = settings
        if state_encoder_settings.memory is not None:
            state_encoder_settings.memory = None
            logger.warning(
                "memory was specified in network_settings but is not supported. "
                "It is being ignored."
            )
        self._action_flattener = ActionFlattener(specs.action_spec)
        new_spec = [
            spec
            for spec in specs.observation_specs
            if spec.observation_type != ObservationType.GOAL_SIGNAL
        ]
        # Exactly one goal-signal observation is expected.
        diverse_spec = [
            spec
            for spec in specs.observation_specs
            if spec.observation_type == ObservationType.GOAL_SIGNAL
        ][0]
        logger.debug(f"MEDE encoder specs: {new_spec}, goal-signal spec: {diverse_spec}")
        self._all_obs_specs = specs.observation_specs
        self.diverse_size = diverse_spec.shape[0]
        if self._use_actions:
            # Condition the encoder on the flattened action as well as the observations.
            self._encoder = NetworkBody(
                new_spec, state_encoder_settings, self._action_flattener.flattened_size
            )
        else:
            self._encoder = NetworkBody(new_spec, state_encoder_settings)
        self._last_layer = torch.nn.Linear(
            state_encoder_settings.hidden_units, self.diverse_size
        )
        self._diverse_index = -1
        self._max_index = len(specs.observation_specs)
        for i, spec in enumerate(specs.observation_specs):
            if spec.observation_type == ObservationType.GOAL_SIGNAL:
                self._diverse_index = i
    def predict(self, obs_input, action_input, detach_action=False) -> torch.Tensor:
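        """
        Predict a softmax distribution over the goal-signal dimensions from the
        non-goal observations (and, when enabled, the taken action).
        """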
        # Drop the goal-signal observation; the network has to infer it.
        tensor_obs = [
            obs
            for obs, spec in zip(obs_input, self._all_obs_specs)
            if spec.observation_type != ObservationType.GOAL_SIGNAL
        ]
        if self._use_actions:
            action = self._action_flattener.forward(action_input)
            if detach_action:
                action = action.detach()
            hidden, _ = self._encoder.forward(tensor_obs, action)
        else:
            hidden, _ = self._encoder.forward(tensor_obs)
        # TODO: add a VAE (as in VAIL)?
        prediction = torch.softmax(self._last_layer(hidden), dim=1)
        return prediction
    def copy_normalization(self, actor_body) -> None:
        # Assumes the goal signal is observation 0, so the encoder's first
        # processor corresponds to the actor's second one.
        self._encoder.processors[0].copy_normalization(actor_body.processors[1])
    def rewards(self, obs_input, action_input, detach_action=False) -> torch.Tensor:
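        """
        Log-likelihood of the goal signal that was actually observed, used as a
        per-step diversity pseudo-reward.
        """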
        truth = obs_input[self._diverse_index]
        prediction = self.predict(obs_input, action_input, detach_action)
        rewards = torch.log(torch.sum((prediction * truth), dim=1) + self.EPSILON)
        return rewards
    def loss(self, obs_input, action_input, masks, detach_action=True) -> torch.Tensor:
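        """
        Negative masked mean of the diversity rewards; minimized to train the
        network itself.
        """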
        return -ModelUtils.masked_mean(
            self.rewards(obs_input, action_input, detach_action), masks
        )

class TorchSACOptimizer(TorchOptimizer):
    class PolicyValueNetwork(nn.Module):
        ...  # unchanged stock SAC network

    def __init__(self, policy, trainer_settings):
        # ... (stock SAC setup)
        value_params = list(self.q_network.parameters()) + list(
            self._critic.parameters()
        )
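        # MEDE additions: a network that predicts the goal signal from the other
        # observations and the action, trained with its own Adam optimizer.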
        self._mede_network = DiverseNetwork(
            self.policy.behavior_spec, self.policy.network_settings
        )
        self._mede_optimizer = torch.optim.Adam(
            list(self._mede_network.parameters()), lr=hyperparameters.learning_rate
        )
        logger.debug("value_vars")
        for param in value_params:
            logger.debug(param.shape)
    def sac_value_loss(
        self,
        log_probs: ActionLogProbs,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
        obs,
        act,
    ) -> torch.Tensor:
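        """
        Standard SAC value loss, with the scaled MEDE reward for (obs, act)
        added to the value backup target.
        """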
        min_policy_qs = {}
        with torch.no_grad():
            _cont_ent_coef = self._log_ent_coef.continuous
            if self._action_spec.discrete_size <= 0:
                for name in values.keys():
                    min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
            # ... (discrete-action branch unchanged)
        for name in values.keys():
            with torch.no_grad():
                # MEDE: add the scaled diversity reward to the value backup.
                v_backup = (
                    min_policy_qs[name]
                    - torch.sum(_cont_ent_coef * log_probs.continuous_tensor, dim=1)
                    + self._mede_network.STRENGTH
                    * self._mede_network.rewards(obs, act)
                )
            value_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
            )
        # ... (value losses aggregated and returned as in the stock optimizer)

    def sac_policy_loss(
        self,
        log_probs: ActionLogProbs,
        q1p_outs: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
        obs,
        act,
    ) -> torch.Tensor:
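        """
        Standard SAC policy loss minus the scaled MEDE reward, so the policy is
        also pushed toward actions from which the goal signal can be recovered.
        """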
        _cont_ent_coef, _disc_ent_coef = (
            self._log_ent_coef.continuous,
            self._log_ent_coef.discrete,
        )
        # ... (mean-Q computation and batch_policy_loss initialization unchanged)
        cont_log_probs = log_probs.continuous_tensor
        batch_policy_loss += torch.mean(
            _cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
        ) - self._mede_network.STRENGTH * self._mede_network.rewards(obs, act)
        policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
        return policy_loss

    def update(self, batch, num_sequences):
        # ... (batch-to-tensor conversion elided)
        self.target_network.network_body.copy_normalization(
            self.policy.actor.network_body
        )
        self._mede_network.copy_normalization(self.policy.actor.network_body)
        self._critic.network_body.copy_normalization(self.policy.actor.network_body)
        sampled_actions, log_probs, _, _ = self.policy.actor.get_action_and_stats(
            current_obs,
            # ... (remaining arguments unchanged)
        )
        q1_loss, q2_loss = self.sac_q_loss(
            q1_stream, q2_stream, target_values, dones, rewards, masks
        )
        value_loss = self.sac_value_loss(
            log_probs,
            value_estimates,
            q1p_out,
            q2p_out,
            masks,
            current_obs,
            sampled_actions,
        )
        policy_loss = self.sac_policy_loss(
            log_probs, q1p_out, masks, current_obs, sampled_actions
        )
        entropy_loss = self.sac_entropy_loss(log_probs, masks)
        total_value_loss = q1_loss + q2_loss
        entropy_loss.backward()
        self.entropy_optimizer.step()
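
        # Train the MEDE network on the same batch with its own optimizer.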
        mede_loss = self._mede_network.loss(current_obs, sampled_actions, masks)
        ModelUtils.update_learning_rate(self._mede_optimizer, decay_lr)
        self._mede_optimizer.zero_grad()
        mede_loss.backward()
        self._mede_optimizer.step()
        # Update target network
        ModelUtils.soft_update(self._critic, self.target_network, self.tau)

        update_stats = {
            # ... (stock loss statistics unchanged)
            "Policy/Continuous Entropy Coeff": torch.mean(
                torch.exp(self._log_ent_coef.continuous)
            ).item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/MEDE Loss": mede_loss.item(),
        }
        return update_stats
" Optimizer:policy_optimizer " : self . policy_optimizer ,
" Optimizer:value_optimizer " : self . value_optimizer ,
" Optimizer:entropy_optimizer " : self . entropy_optimizer ,
" Optimizer:mede_optimizer " : self . _mede_optimizer ,
}
for reward_provider in self . reward_signals . values ( ) :
modules . update ( reward_provider . get_modules ( ) )