浏览代码

Proper mask mean for PPO

/develop/add-fire/memoryclass
Ervin Teng 4 年前
当前提交
1d4bc99e
共有 2 个文件被更改，包括 17 次插入和 7 次删除
  1. 14
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  2. 10
      ml-agents/mlagents/trainers/torch/utils.py

14
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


)
v_opt_a = (returns_tensor - head) ** 2
v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
masked_loss = torch.max(v_opt_a, v_opt_b) * loss_masks
value_loss = torch.mean(masked_loss)
value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
value_losses.append(value_loss)
value_loss = torch.mean(torch.stack(value_losses))
return value_loss

p_opt_b = (
torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) * advantage
)
masked_loss = torch.min(p_opt_a, p_opt_b).flatten() * loss_masks
policy_loss = -torch.mean(masked_loss)
policy_loss = -1 * ModelUtils.masked_mean(
torch.min(p_opt_a, p_opt_b).flatten(), loss_masks
)
return policy_loss
@timed

if self.policy.use_continuous_act:
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
else:
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.bool)
memories = [
ModelUtils.list_to_tensor(batch["memory"][i])

memories=memories,
seq_len=self.policy.sequence_length,
)
loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.float32)
loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
value_loss = self.ppo_value_loss(
values, old_values, returns, decay_eps, loss_masks
)

loss = (
policy_loss
+ 0.5 * value_loss
- decay_bet * torch.mean(entropy.flatten() * loss_masks)
- decay_bet * ModelUtils.masked_mean(entropy.flatten(), loss_masks)
)
# Set optimizer learning rate

10
ml-agents/mlagents/trainers/torch/utils.py


else:
all_probs = torch.cat(all_probs_list, dim=-1)
return log_probs, entropies, all_probs
@staticmethod
def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
"""
Returns the mean of the tensor but ignoring the values specified by masks.
Used for masking out loss functions.
:param tensor: Tensor which needs mean computation.
:param masks: Boolean tensor of masks with same dimension as tensor.
"""
return (tensor * masks).sum() / masks.float().sum()
正在加载...
取消
保存