        :param num_sequences: Number of sequences to process.
        :return: Results of update.
        """
        # Get decayed parameters
        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        decay_eps = self.decay_epsilon.get_value(self.policy.get_current_step())
        decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
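        # For each reward signal (e.g. extrinsic, curiosity), pull the stored
        # value estimates and computed returns from the batch; both feed the
        # clipped value loss below.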
        returns = {}
        old_values = {}
        for name in self.reward_signals:
            old_values[name] = ModelUtils.list_to_tensor(
                batch[f"{name}_value_estimates"]
            )
            returns[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns"])
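
        # Rebuild the list of observations (one entry per observation spec) from the buffer.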
        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
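
        # Action masks (used for discrete action branches) and the actions taken in the batch.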
        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        actions = AgentAction.from_dict(batch)
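
        # For recurrent policies, pull the stored memory at the start of each
        # sequence and stack them along a leading batch dimension; the list stays
        # empty for non-recurrent policies.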
        memories = [
            ModelUtils.list_to_tensor(batch["memory"][i])
            for i in range(0, len(batch["memory"]), self.policy.sequence_length)
        ]
        if len(memories) > 0:
            memories = torch.stack(memories).unsqueeze(0)
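
        # Re-evaluate the sampled actions under the current policy to get
        # up-to-date log-probabilities, entropy, and value estimates.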
        log_probs, entropy, values = self.policy.evaluate_actions(
            current_obs,
            masks=act_masks,
            actions=actions,
            memories=memories,
            seq_len=self.policy.sequence_length,
        )
        old_log_probs = ActionLogProbs.from_dict(batch).flatten()
        log_probs = log_probs.flatten()
        loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
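        # Clipped value loss and clipped surrogate policy loss, both averaged
        # only over valid (unmasked) timesteps.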
        value_loss = self.ppo_value_loss(
            values, old_values, returns, decay_eps, loss_masks
        )
        # print(log_probs)
        policy_loss = self.ppo_policy_loss(
            ModelUtils.list_to_tensor(batch["advantages"]),
            log_probs,
            old_log_probs,
            loss_masks,
        )
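        # Total PPO loss: policy term plus 0.5-weighted value term, minus the
        # decayed entropy bonus.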
        loss = (
            policy_loss
            + 0.5 * value_loss
            - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
        )

        # Set optimizer learning rate
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
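        # Debugging aid: autograd anomaly detection makes the backward pass raise
        # as soon as it produces NaN gradients, at the cost of slower updates.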
        with torch.autograd.detect_anomaly():
            loss.backward()

        self.optimizer.step()
        update_stats = {