from typing import Tuple

# Import paths assume the standard ml-agents (mlagents / mlagents_envs) package layout.
from mlagents.torch_utils import torch
from mlagents_envs.base_env import BehaviorSpec, ObservationType
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.torch.action_flattener import ActionFlattener
from mlagents.trainers.torch.layers import linear_layer, Initialization
from mlagents.trainers.torch.utils import ModelUtils

logger = get_logger(__name__)


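# MEDE-style diversity discriminator with a VAIL-like variational bottleneck.
# It encodes the non-goal observations (optionally together with the action),
# samples a latent code around that encoding, and predicts the goal signal as
# a categorical target. The log-likelihood of the true goal is used as an
# intrinsic diversity reward, while a KL penalty weighted by an adaptively
# tuned beta constrains the information carried by the latent code.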
class DiverseNetworkVariational(torch.nn.Module):
    EPSILON = 1e-7
    STRENGTH = 1.0
    # gradient_penalty_weight = 10.0
    z_size = 128  # must equal the encoder's hidden size: z_mu is taken directly from its output
    alpha = 0.0005  # dual-ascent step size for the beta update in loss()
    mutual_information = 10  # 0.5
    initial_beta = 0.0

    def __init__(self, specs: BehaviorSpec, settings) -> None:
        super().__init__()
        self._use_actions = True
        print(
            "VARIATIONAL : Settings : strength:", self.STRENGTH,
            " use_actions:", self._use_actions,
            " mutual_information : ", self.mutual_information,
        )
        # state_encoder_settings = settings
        state_encoder_settings = NetworkSettings(normalize=True, num_layers=1)
        if state_encoder_settings.memory is not None:
            state_encoder_settings.memory = None
            logger.warning(
                "memory was specified in network_settings but is not supported. It is being ignored."
            )

        self._action_flattener = ActionFlattener(specs.action_spec)
        # Encode every observation except the goal signal; the goal is the
        # prediction target.
        new_spec = [
            spec
            for spec in specs.observation_specs
            if spec.observation_type != ObservationType.GOAL_SIGNAL
        ]
        diverse_spec = [
            spec
            for spec in specs.observation_specs
            if spec.observation_type == ObservationType.GOAL_SIGNAL
        ][0]

        print(" > ", new_spec, "\n\n\n", " >> ", diverse_spec)
        self._all_obs_specs = specs.observation_specs
        self.diverse_size = diverse_spec.shape[0]

        if self._use_actions:
            self._encoder = NetworkBody(
                new_spec, state_encoder_settings, self._action_flattener.flattened_size
            )
        else:
            self._encoder = NetworkBody(new_spec, state_encoder_settings)

        # Learned per-dimension std of the variational bottleneck.
        self._z_sigma = torch.nn.Parameter(
            torch.ones((self.z_size), dtype=torch.float), requires_grad=True
        )
        # self._z_mu_layer = linear_layer(
        #     state_encoder_settings.hidden_units,
        #     self.z_size,
        #     kernel_init=Initialization.KaimingHeNormal,
        #     kernel_gain=0.1,
        # )
        # Dual variable for the KL constraint; updated manually in loss(), not
        # by the optimizer.
        self._beta = torch.nn.Parameter(
            torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
        )

        self._last_layer = torch.nn.Linear(self.z_size, self.diverse_size)
        self._diverse_index = -1
        self._max_index = len(specs.observation_specs)
        for i, spec in enumerate(specs.observation_specs):
            if spec.observation_type == ObservationType.GOAL_SIGNAL:
                self._diverse_index = i

    def predict(
        self, obs_input, action_input, detach_action=False, var_noise=True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
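        """Return (goal prediction, latent mean).

        The goal-signal observation is excluded from the inputs, so the network
        has to infer it. With var_noise=False the bottleneck noise is disabled
        and the sampled latent equals its mean.
        """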
        tensor_obs = [
            obs
            for obs, spec in zip(obs_input, self._all_obs_specs)
            if spec.observation_type != ObservationType.GOAL_SIGNAL
        ]
        if self._use_actions:
            action = self._action_flattener.forward(action_input).reshape(
                -1, self._action_flattener.flattened_size
            )
            if detach_action:
                action = action.detach()
            hidden, _ = self._encoder.forward(tensor_obs, action)
        else:
            hidden, _ = self._encoder.forward(tensor_obs)

        # Variational bottleneck (as in VAIL): sample around the encoder output.
        # z_mu = self._z_mu_layer(hidden)
        z_mu = hidden
        # With var_noise=False the std is zeroed, so hidden equals z_mu.
        hidden = torch.normal(z_mu, self._z_sigma * var_noise)

        prediction = torch.softmax(self._last_layer(hidden), dim=1)
        return prediction, z_mu

    def copy_normalization(self, other_body) -> None:
        # The goal signal is excluded from this network's observation list, so
        # the processor indices are offset relative to the source network body.
        self._encoder.processors[0].copy_normalization(other_body.processors[1])

    def rewards(
        self, obs_input, action_input, detach_action=False, var_noise=True
    ) -> torch.Tensor:
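        """Diversity reward: log-probability assigned to the true goal signal."""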
        truth = obs_input[self._diverse_index]
        prediction, _ = self.predict(obs_input, action_input, detach_action, var_noise)
        rewards = torch.log(torch.sum((prediction * truth), dim=1) + self.EPSILON)
        return rewards

    def loss(
        self, obs_input, action_input, masks, detach_action=True, var_noise=True
    ) -> Tuple[torch.Tensor, ...]:
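        """VAIL-style objective: negative expected diversity reward plus
        beta * (KL - mutual_information), with beta updated by dual ascent.
        Returns (total_loss, base_loss, kl_loss, vail_loss, beta).
        """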
        # print( ">>> ",obs_input[self._diverse_index][0],self.predict(obs_input, action_input, detach_action)[0], self.predict([x*0 for x in obs_input], action_input, detach_action * 0)[0] )
        base_loss = -ModelUtils.masked_mean(
            self.rewards(obs_input, action_input, detach_action, var_noise), masks
        )

        _, mu = self.predict(obs_input, action_input, detach_action, var_noise)
        kl_loss = ModelUtils.masked_mean(
            -torch.sum(
                1
                + (self._z_sigma ** 2).log()
                - 0.5 * mu ** 2
                # - 0.5 * mu_expert ** 2
                - (self._z_sigma ** 2),
                dim=1,
            ),
            masks,
        )
        vail_loss = self._beta * (kl_loss - self.mutual_information)
        with torch.no_grad():
            # Dual ascent: increase beta while the KL exceeds the target, clamp at 0.
            self._beta.data = torch.max(
                self._beta + self.alpha * (kl_loss - self.mutual_information),
                torch.tensor(0.0),
            )
        total_loss = base_loss + vail_loss

        return total_loss, base_loss, kl_loss, vail_loss, self._beta

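
# Non-variational variant of the diversity network. Only a partial excerpt is
# kept below; __init__ and predict are omitted (presumably mirroring the class
# above without the variational bottleneck).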
class DiverseNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def copy_normalization(self, other_body) -> None:
        self._encoder.processors[0].copy_normalization(other_body.processors[1])

    def rewards(
        self, obs_input, action_input, detach_action=False, var_noise=False
    ) -> torch.Tensor:
        # var_noise is accepted for call compatibility with the variational
        # network above and is ignored here.
        truth = obs_input[self._diverse_index]
        prediction = self.predict(obs_input, action_input, detach_action)
        rewards = torch.log(torch.sum((prediction * truth), dim=1) + self.EPSILON)
        return rewards

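
# ---------------------------------------------------------------------------
# Integration excerpts from the trainer's actor-critic (SAC-style) optimizer.
# Names such as self.policy, self._critic, min_policy_qs, _cont_ent_coef and
# loss_masks belong to that optimizer; the fragments below are excerpts, not a
# complete class.
# ---------------------------------------------------------------------------

# In the optimizer's __init__, after the existing optimizers are built: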
        # ... (tail of the preceding optimizer construction)
            self._critic.parameters()
        )

        # self._mede_network = DiverseNetwork(
        self._mede_network = DiverseNetworkVariational(
            self.policy.behavior_spec, self.policy.network_settings
        )
        self._mede_optimizer = torch.optim.Adam(
            self._mede_network.parameters()
        )

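
# Value-loss computation: the diversity reward, scaled by STRENGTH, is added
# to the entropy-regularized value target.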
        v_backup = (
            min_policy_qs[name]
            - torch.sum(_cont_ent_coef * log_probs.continuous_tensor, dim=1)
            + self._mede_network.STRENGTH
            * self._mede_network.rewards(obs, act, var_noise=False)
        )
        value_loss = 0.5 * ModelUtils.masked_mean(
            torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
        )

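
# Policy-loss computation: the policy is additionally rewarded when the
# discriminator can attribute its behavior to the current goal signal.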
        batch_policy_loss += torch.mean(
            _cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
        )
        batch_policy_loss += (
            -self._mede_network.STRENGTH
            * self._mede_network.rewards(obs, act, var_noise=False)
        )
        policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)

        return policy_loss

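
# update(): the diversity network is trained right after the entropy
# coefficient update.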
        entropy_loss.backward()
        self.entropy_optimizer.step()

        # mede_loss = self._mede_network.loss(current_obs, sampled_actions, masks)
        mede_loss, base_loss, kl_loss, vail_loss, beta = self._mede_network.loss(
            current_obs, sampled_actions, masks
        )
        ModelUtils.update_learning_rate(self._mede_optimizer, decay_lr)
        self._mede_optimizer.zero_grad()
        mede_loss.backward()

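
# The MEDE loss components are reported with the other update statistics.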
"Policy/Learning Rate": decay_lr, |
|
|
|
"Policy/Entropy Loss": entropy_loss.item(), |
|
|
|
"Policy/MEDE Loss": mede_loss.item(), |
|
|
|
"Policy/MEDE Base": base_loss.item(), |
|
|
|
"Policy/MEDE Variational": vail_loss.item(), |
|
|
|
"Policy/MEDE KL": kl_loss.item(), |
|
|
|
"Policy/MEDE beta": beta.item(), |
|
|
|
} |
|
|
|
|
|
|
|
return update_stats |
|
|
|