
Adding a variational version

Branch: /exp-diverse-behavior
Author: vincentpierre, 3 years ago
Commit: b4f30613
1 file changed: 141 insertions, 5 deletions
ml-agents/mlagents/trainers/sac/optimizer_torch.py (146 changed lines)

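In brief: the diff below adds a DiverseNetworkVariational class alongside the existing DiverseNetwork discriminator used for the MEDE diversity reward, and switches the SAC optimizer to the new class. The new discriminator predicts the GOAL_SIGNAL observation from the remaining observations (and optionally the actions) through a variational bottleneck in the style of VAIL: the latent code is sampled from a Gaussian with a learned per-dimension standard deviation, a KL penalty pushes that code towards a unit Gaussian, and the penalty weight beta is adapted towards a mutual_information target. The reward consumed by the value and policy losses is computed with var_noise=False, i.e. without the sampling noise.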

from mlagents.trainers.torch.networks import NetworkBody
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.layers import linear_layer, Initialization


class DiverseNetworkVariational(torch.nn.Module):
    STRENGTH = 1.0
    # gradient_penalty_weight = 10.0
    z_size = 128
    alpha = 0.0005
    mutual_information = 10  # 0.5
    EPSILON = 1e-7  # the diff defines this twice (1e-10, then 1e-7); the later value wins
    initial_beta = 0.0

    def __init__(self, specs: BehaviorSpec, settings) -> None:
        super().__init__()
        self._use_actions = True
        print(
            "VARIATIONAL : Settings : strength:", self.STRENGTH,
            " use_actions:", self._use_actions,
            " mutual_information : ", self.mutual_information,
        )
        # state_encoder_settings = settings
        state_encoder_settings = NetworkSettings(normalize=True, num_layers=1)
        if state_encoder_settings.memory is not None:
            state_encoder_settings.memory = None
            logger.warning(
                "memory was specified in network_settings but is not supported. It is being ignored."
            )
        self._action_flattener = ActionFlattener(specs.action_spec)
        # Encode every observation except the goal signal; the goal signal is the
        # quantity the network is trained to predict.
        new_spec = [
            spec
            for spec in specs.observation_specs
            if spec.observation_type != ObservationType.GOAL_SIGNAL
        ]
        diverse_spec = [
            spec
            for spec in specs.observation_specs
            if spec.observation_type == ObservationType.GOAL_SIGNAL
        ][0]
        print(" > ", new_spec, "\n\n\n", " >> ", diverse_spec)
        self._all_obs_specs = specs.observation_specs
        self.diverse_size = diverse_spec.shape[0]
        if self._use_actions:
            self._encoder = NetworkBody(
                new_spec, state_encoder_settings, self._action_flattener.flattened_size
            )
        else:
            self._encoder = NetworkBody(new_spec, state_encoder_settings)
        # Learnable per-dimension standard deviation of the latent code.
        self._z_sigma = torch.nn.Parameter(
            torch.ones((self.z_size), dtype=torch.float), requires_grad=True
        )
        # self._z_mu_layer = linear_layer(
        #     state_encoder_settings.hidden_units,
        #     self.z_size,
        #     kernel_init=Initialization.KaimingHeNormal,
        #     kernel_gain=0.1,
        # )
        # Lagrange multiplier for the KL constraint, updated manually in loss(), not by the optimizer.
        self._beta = torch.nn.Parameter(
            torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
        )
        self._last_layer = torch.nn.Linear(self.z_size, self.diverse_size)
        self._diverse_index = -1
        self._max_index = len(specs.observation_specs)
        for i, spec in enumerate(specs.observation_specs):
            if spec.observation_type == ObservationType.GOAL_SIGNAL:
                self._diverse_index = i

    def predict(self, obs_input, action_input, detach_action=False, var_noise=True) -> torch.Tensor:
        # Drop the goal-signal observation from the inputs.
        tensor_obs = [
            obs
            for obs, spec in zip(obs_input, self._all_obs_specs)
            if spec.observation_type != ObservationType.GOAL_SIGNAL
        ]
        if self._use_actions:
            action = self._action_flattener.forward(action_input).reshape(
                -1, self._action_flattener.flattened_size
            )
            if detach_action:
                action = action.detach()
            hidden, _ = self._encoder.forward(tensor_obs, action)
        else:
            hidden, _ = self._encoder.forward(tensor_obs)
        # Variational bottleneck (as in VAIL): sample the code around the encoder output.
        # Passing var_noise=False zeroes the standard deviation and disables the sampling noise.
        # z_mu = self._z_mu_layer(hidden)
        z_mu = hidden  # self._z_mu_layer(hidden)
        hidden = torch.normal(z_mu, self._z_sigma * var_noise)
        prediction = torch.softmax(self._last_layer(hidden), dim=1)
        return prediction, z_mu

    def copy_normalization(self, thing):
        self._encoder.processors[0].copy_normalization(thing.processors[1])

    def rewards(self, obs_input, action_input, detach_action=False, var_noise=True) -> torch.Tensor:
        # Log-probability assigned to the true goal signal.
        truth = obs_input[self._diverse_index]
        prediction, _ = self.predict(obs_input, action_input, detach_action, var_noise)
        rewards = torch.log(torch.sum((prediction * truth), dim=1) + self.EPSILON)
        return rewards

    def loss(self, obs_input, action_input, masks, detach_action=True, var_noise=True) -> torch.Tensor:
        # print( ">>> ",obs_input[self._diverse_index][0],self.predict(obs_input, action_input, detach_action)[0], self.predict([x*0 for x in obs_input], action_input, detach_action * 0)[0] )
        base_loss = -ModelUtils.masked_mean(
            self.rewards(obs_input, action_input, detach_action, var_noise), masks
        )
        _, mu = self.predict(obs_input, action_input, detach_action, var_noise)
        kl_loss = ModelUtils.masked_mean(
            -torch.sum(
                1
                + (self._z_sigma ** 2).log()
                - 0.5 * mu ** 2
                # - 0.5 * mu_expert ** 2
                - (self._z_sigma ** 2),
                dim=1,
            ),
            masks,
        )
        vail_loss = self._beta * (kl_loss - self.mutual_information)
        # Dual gradient ascent on beta to keep the KL close to the mutual_information target.
        with torch.no_grad():
            self._beta.data = torch.max(
                self._beta + self.alpha * (kl_loss - self.mutual_information),
                torch.tensor(0.0),
            )
        total_loss = base_loss + vail_loss
        return total_loss, base_loss, kl_loss, vail_loss, self._beta
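
The loss above follows the VAIL-style variational information bottleneck: the code is sampled around the encoder output, a KL penalty keeps it close to a unit Gaussian, and the multiplier beta is adjusted by dual gradient ascent so the KL stays near the mutual_information target. A minimal, self-contained sketch of that beta update, using illustrative tensors rather than the optimizer's real inputs:

import torch

# Illustrative constants mirroring the class attributes above.
mutual_information = 10.0        # target for the KL term (the information constraint)
alpha = 0.0005                   # step size of the beta update
beta = torch.tensor(0.0)         # corresponds to initial_beta

z_size = 128
mu = 0.1 * torch.randn(4, z_size)    # stand-in for the encoder output of a batch of 4
sigma = torch.ones(z_size)           # stand-in for the learned _z_sigma parameter

# KL penalty written as in the commit. Note the standard closed-form
# KL(N(mu, sigma^2) || N(0, 1)) would scale the whole sum by 0.5, not only the mu**2 term.
kl = -torch.sum(1 + (sigma ** 2).log() - 0.5 * mu ** 2 - sigma ** 2, dim=1).mean()

vail_loss = beta * (kl - mutual_information)

# beta grows while the KL exceeds the target and is clamped at zero otherwise,
# so the bottleneck only starts to bite once the code carries "too much" information.
with torch.no_grad():
    beta = torch.max(beta + alpha * (kl - mutual_information), torch.tensor(0.0))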
class DiverseNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def copy_normalization(self, thing):
        self._encoder.processors[0].copy_normalization(thing.processors[1])

    def rewards(self, obs_input, action_input, detach_action=False, var_noise=False) -> torch.Tensor:
        # var_noise is unused here; it only keeps the signature compatible with
        # DiverseNetworkVariational.rewards.
        truth = obs_input[self._diverse_index]
        prediction = self.predict(obs_input, action_input, detach_action)
        rewards = torch.log(torch.sum((prediction * truth), dim=1) + self.EPSILON)
        return rewards

            self._critic.parameters()
        )
        # self._mede_network = DiverseNetwork(
        self._mede_network = DiverseNetworkVariational(
            self.policy.behavior_spec, self.policy.network_settings
        )
        self._mede_optimizer = torch.optim.Adam(

                min_policy_qs[name]
                - torch.sum(_cont_ent_coef * log_probs.continuous_tensor, dim=1)
                + self._mede_network.STRENGTH
                * self._mede_network.rewards(obs, act, var_noise=False)
            )
            value_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(values[name], v_backup), loss_masks

        batch_policy_loss += torch.mean(
            _cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
        )
        batch_policy_loss += -self._mede_network.STRENGTH * self._mede_network.rewards(
            obs, act, var_noise=False
        )
        policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
        return policy_loss
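
Taken together, the two hunks above add the MEDE log-probability reward, evaluated deterministically (var_noise=False), to both the value target and the policy objective. A rough, self-contained sketch of the shape of that computation; min_q, log_pi, ent_coef and diversity_reward are placeholder names, not the optimizer's actual variables:

import torch

batch = 256
min_q = torch.randn(batch)              # min over the Q-heads for the sampled actions
log_pi = torch.randn(batch)             # log-probability of the sampled actions
ent_coef = 0.2                          # entropy coefficient (illustrative value)
diversity_reward = torch.randn(batch)   # stand-in for _mede_network.rewards(obs, act, var_noise=False)
STRENGTH = 1.0

# Value target: soft value estimate plus the scaled diversity bonus.
v_backup = min_q - ent_coef * log_pi + STRENGTH * diversity_reward

# Policy loss: maximize soft Q, entropy, and the diversity bonus.
policy_loss = (ent_coef * log_pi - min_q - STRENGTH * diversity_reward).mean()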

        entropy_loss.backward()
        self.entropy_optimizer.step()

        mede_loss, base_loss, kl_loss, vail_loss, beta = self._mede_network.loss(
            current_obs, sampled_actions, masks
        )
        # mede_loss = self._mede_network.loss(current_obs, sampled_actions, masks)
        ModelUtils.update_learning_rate(self._mede_optimizer, decay_lr)
        self._mede_optimizer.zero_grad()
        mede_loss.backward()

"Policy/Learning Rate": decay_lr,
"Policy/Entropy Loss": entropy_loss.item(),
"Policy/MEDE Loss": mede_loss.item(),
"Policy/MEDE Base": base_loss.item(),
"Policy/MEDE Variational": vail_loss.item(),
"Policy/MEDE KL": kl_loss.item(),
"Policy/MEDE beta": beta.item(),
}
return update_stats
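
The new stats decompose as follows: Policy/MEDE Loss is Policy/MEDE Base plus Policy/MEDE Variational, and Policy/MEDE Variational equals beta times (Policy/MEDE KL minus mutual_information). Since beta starts at 0 and only grows while the KL exceeds the target, the variational term is inactive early in training.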
