
Black format

/develop/coma2/samenet
Ervin Teng, 4 years ago
Commit 64b34759
1 file changed, with 6 insertions and 2 deletions
ml-agents/mlagents/trainers/ppo/optimizer_torch.py (8 lines changed)


         )
         self.optimizer = torch.optim.Adam(
-            params, lr=self.trainer_settings.hyperparameters.learning_rate, weight_decay=1e-6
+            params,
+            lr=self.trainer_settings.hyperparameters.learning_rate,
+            weight_decay=1e-6,
         )
         self.stats_name_to_update_name = {
             "Losses/Value Loss": "value_loss",

         value_loss = torch.mean(torch.stack(value_losses))
         return value_loss

-    def coma_regularizer_loss(self, values: Dict[str, torch.Tensor], baseline_values: Dict[str, torch.Tensor]):
+    def coma_regularizer_loss(
+        self, values: Dict[str, torch.Tensor], baseline_values: Dict[str, torch.Tensor]
+    ):
         reg_losses = []
         for name, head in values.items():
             reg_loss = torch.nn.functional.mse_loss(head, baseline_values[name])
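
The hunk ends before the body of coma_regularizer_loss is complete, so the rest of the function is not visible here. For readers who want to follow the arithmetic, the following is a dependency-free sketch of what a per-head MSE regularizer of this shape might compute. The names mse and regularizer_loss_sketch are hypothetical. The final averaging across heads is an assumption that mirrors the torch.mean(torch.stack(value_losses)) line above, not something this commit confirms. Plain Python lists stand in for tensors, and the per-head step matches the default "mean" reduction of torch.nn.functional.mse_loss.

```python
from typing import Dict, List


def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error over paired elements ("mean" reduction)."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def regularizer_loss_sketch(
    values: Dict[str, List[float]], baseline_values: Dict[str, List[float]]
) -> float:
    """Per-head MSE between values and baseline values, averaged over heads.

    Assumption: heads are combined by a plain mean, mirroring how
    value_loss is reduced in the diff above.
    """
    reg_losses = []
    for name, head in values.items():
        # Each reward head is compared against the baseline head of the same name.
        reg_losses.append(mse(head, baseline_values[name]))
    return sum(reg_losses) / len(reg_losses)


# Illustrative inputs only; the head names are made up for this example.
values = {"extrinsic": [1.0, 2.0], "curiosity": [0.0, 4.0]}
baselines = {"extrinsic": [1.0, 4.0], "curiosity": [2.0, 4.0]}
loss = regularizer_loss_sketch(values, baselines)
```

Both heads in this example have a per-head MSE of 2.0, so the averaged loss is also 2.0. In the real optimizer the same computation would run on torch.Tensor values so that gradients can flow through it.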
