|
|
|
|
|
|
|
|
|
|
# Imports assumed from the ML-Agents package layout this module appears to build
# on; adjust the import paths to match the installed ml-agents version.
import torch

from mlagents_envs.base_env import ObservationType
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils

logger = get_logger(__name__)


class DiverseNetwork(torch.nn.Module):
    # Small constant for numerical stability.
    EPSILON = 1e-10
    # Scale applied to the diversity reward when it is mixed into the SAC losses.
    STRENGTH = 1.0
|
|
|
        logger.debug(
            "Settings: strength=%s use_actions=%s", self.STRENGTH, self._use_actions
        )
        # A fresh NetworkSettings (normalization enabled) is used for the state
        # encoder in place of the settings passed in.
        state_encoder_settings = NetworkSettings(True)
        if state_encoder_settings.memory is not None:
            state_encoder_settings.memory = None
            logger.warning(
                "memory was specified in network_settings but is not supported; it will be ignored."
            )
|
|
|
|
|
|
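        # Observations tagged as GOAL_SIGNAL are excluded from the encoder input;
        # presumably the network has to infer the goal/diversity setting from the
        # remaining observations.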
        tensor_obs = [
            obs
            for obs, spec in zip(obs_input, self._all_obs_specs)
            if spec.observation_type != ObservationType.GOAL_SIGNAL
        ]
|
|
|
|
|
|
|
        action = self._action_flattener.forward(action_input).reshape(
            -1, self._action_flattener.flattened_size
        )
        if detach_action:
            # Keep gradients from flowing back into the policy through the action.
            action = action.detach()
        hidden, _ = self._encoder.forward(tensor_obs, action)
|
|
|
|
|
|
return rewards |
|
|
|
|
|
|
|
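    # The training loss maximizes the (masked) mean diversity reward by minimizing
    # its negation; `masks` excludes padded or invalid timesteps from the mean.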
    def loss(self, obs_input, action_input, masks, detach_action=True) -> torch.Tensor:
        return -ModelUtils.masked_mean(
            self.rewards(obs_input, action_input, detach_action), masks
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
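                # The value backup target folds in the scaled MEDE diversity reward
                # alongside the entropy bonus, so the critics learn values that
                # already include the diversity bonus.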
                with torch.no_grad():
                    v_backup = (
                        min_policy_qs[name]
                        - torch.mean(branched_ent_bonus, axis=0)
                        + self._mede_network.STRENGTH
                        * self._mede_network.rewards(obs, act)
                    )
                # NOTE: the discrete case is much more complicated than this.
|
|
|
                # Add continuous entropy bonus to minimum Q
                if self._action_spec.continuous_size > 0:
                    v_backup += torch.sum(
                        _cont_ent_coef * log_probs.continuous_tensor, dim=1
                    )
|
|
|
|
|
|
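            # Continuous-action branch of the policy loss: the usual
            # entropy-regularized Q objective from SAC.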
            cont_log_probs = log_probs.continuous_tensor
            batch_policy_loss += torch.mean(
                _cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
            )
        # Subtract the scaled MEDE diversity reward so that more diverse behaviour
        # lowers the policy loss.
        batch_policy_loss += (
            -self._mede_network.STRENGTH * self._mede_network.rewards(obs, act)
        )
        policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)

        return policy_loss
|
|
|
|
|
|
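            # Entropy-coefficient update: the gap between the current policy
            # log-probabilities and the target entropy drives the adjustment of
            # _cont_ent_coef.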
            with torch.no_grad():
                cont_log_probs = log_probs.continuous_tensor
                target_current_diff = torch.sum(
                    cont_log_probs + self.target_entropy.continuous, dim=1
                )
            # We update all the _cont_ent_coef as one block
            entropy_loss += -1 * ModelUtils.masked_mean(
                _cont_ent_coef * target_current_diff, loss_masks
            )
|
|
|
|
|
|
        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        self.value_optimizer.zero_grad()
        total_value_loss.backward()
        self.value_optimizer.step()
|
|
|
|
|
|
|
|
|
|
|
        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()
|
|
|
|
|
|
            "Policy/Continuous Entropy Coeff": torch.mean(
                torch.exp(self._log_ent_coef.continuous)
            ).item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/Entropy Loss": entropy_loss.item(),
            "Policy/MEDE Loss": mede_loss.item(),
        }
|
|
|
|
|
|
|
|
|
|
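        # The MEDE network and its optimizer are exposed alongside the standard SAC
        # optimizers, presumably so they are saved and restored with the trainer
        # checkpoint.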
"Optimizer:value_optimizer": self.value_optimizer, |
|
|
|
"Optimizer:entropy_optimizer": self.entropy_optimizer, |
|
|
|
"Optimizer:mede_optimizer": self._mede_optimizer, |
|
|
|
"Optimizer:mede_network": self._mede_network, |
|
|
|
} |
|
|
|
for reward_provider in self.reward_signals.values(): |
|
|
|
modules.update(reward_provider.get_modules()) |