
Merge branch 'develop-coma2-samenet' into develop-coma2-samenet-sum

Ervin Teng, 4 years ago
Commit 1cf27871
1 changed file, with 10 insertions and 6 deletions
ml-agents/mlagents/trainers/ppo/optimizer_torch.py

         )
         self.optimizer = torch.optim.Adam(
-            params, lr=self.trainer_settings.hyperparameters.learning_rate, weight_decay=1e-6
+            params,
+            lr=self.trainer_settings.hyperparameters.learning_rate,
+            weight_decay=1e-6,
         )
         self.stats_name_to_update_name = {
             "Losses/Value Loss": "value_loss",

         value_loss = torch.mean(torch.stack(value_losses))
         return value_loss

-    def coma_regularizer_loss(self, values: Dict[str, torch.Tensor], baseline_values: Dict[str, torch.Tensor]):
+    def coma_regularizer_loss(
+        self, values: Dict[str, torch.Tensor], baseline_values: Dict[str, torch.Tensor]
+    ):
         reg_losses = []
         for name, head in values.items():
             reg_loss = torch.nn.functional.mse_loss(head, baseline_values[name])
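
The hunk cuts off mid-body, before the losses are accumulated and returned. A minimal sketch of what the complete method plausibly looks like, assuming the reduction mirrors the mean-of-stacked-losses pattern that value_loss uses above (written as a free function rather than a method for self-containment):

    from typing import Dict

    import torch

    def coma_regularizer_loss(
        values: Dict[str, torch.Tensor], baseline_values: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        reg_losses = []
        for name, head in values.items():
            # MSE between each value head and the matching baseline head
            reg_loss = torch.nn.functional.mse_loss(head, baseline_values[name])
            reg_losses.append(reg_loss)
        # Average across reward-signal heads, mirroring value_loss above
        return torch.mean(torch.stack(reg_losses))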

         # Regularizer loss reduces bias between the baseline and values. Other
         # regularizers are possible here.
-        regularizer_loss = self.coma_regularizer_loss(values, baseline_vals)
+        # regularizer_loss = self.coma_regularizer_loss(values, baseline_vals)
         policy_loss = self.ppo_policy_loss(
             ModelUtils.list_to_tensor(batch["advantages"]),

         )
         loss = (
             policy_loss
-            + 0.25 * (value_loss + baseline_loss)
-            + 0.25 * regularizer_loss
+            + 0.5 * (value_loss + 0.5 * baseline_loss)
+            # + 0.25 * regularizer_loss
             - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
         )
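
Net effect of this hunk: the value loss weight rises from 0.25 to 0.5, the baseline loss keeps an effective weight of 0.25 (0.5 * 0.5), and the regularizer term is dropped from the total. A minimal sketch of the new combination, with scalar-tensor inputs (total_loss and its parameter names are illustrative, not from the source):

    import torch

    def total_loss(
        policy_loss: torch.Tensor,
        value_loss: torch.Tensor,
        baseline_loss: torch.Tensor,
        masked_entropy: torch.Tensor,
        decay_bet: float,
    ) -> torch.Tensor:
        # Weighting after this merge; the regularizer term is commented out
        # in the source, so it is omitted here.
        return (
            policy_loss
            + 0.5 * (value_loss + 0.5 * baseline_loss)
            - decay_bet * masked_entropy
        )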

"Losses/Value Loss": value_loss.item(),
# "Losses/Q Loss": q_loss.item(),
"Losses/Baseline Value Loss": baseline_loss.item(),
"Losses/Regularization Loss": regularizer_loss.item(),
# "Losses/Regularization Loss": regularizer_loss.item(),
"Policy/Learning Rate": decay_lr,
"Policy/Epsilon": decay_eps,
"Policy/Beta": decay_bet,
