
add weight decay to trainers

/develop/weight-decay
Andrew Cohen, 4 years ago
Commit 1bc2ff96
3 files changed, with 10 insertions and 3 deletions
  1. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (4 changes)
  2. ml-agents/mlagents/trainers/sac/optimizer_torch.py (8 changes)
  3. ml-agents/mlagents/trainers/settings.py (1 change)

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (4 changes)

     )
     self.optimizer = torch.optim.Adam(
-        params, lr=self.trainer_settings.hyperparameters.learning_rate
+        params,
+        lr=self.trainer_settings.hyperparameters.learning_rate,
+        weight_decay=self.trainer_settings.hyperparameters.weight_decay,
     )
     self.stats_name_to_update_name = {
         "Losses/Value Loss": "value_loss",

ml-agents/mlagents/trainers/sac/optimizer_torch.py (8 changes)

         self.trainer_settings.max_steps,
     )
     self.policy_optimizer = torch.optim.Adam(
-        policy_params, lr=hyperparameters.learning_rate
+        policy_params,
+        lr=hyperparameters.learning_rate,
+        weight_decay=hyperparameters.weight_decay,
     )
     self.value_optimizer = torch.optim.Adam(
-        value_params, lr=hyperparameters.learning_rate
+        value_params,
+        lr=hyperparameters.learning_rate,
+        weight_decay=hyperparameters.weight_decay,
     )
     self.entropy_optimizer = torch.optim.Adam(
         self._log_ent_coef.parameters(), lr=hyperparameters.learning_rate

ml-agents/mlagents/trainers/settings.py (1 change)

     buffer_size: int = 10240
     learning_rate: float = 3.0e-4
     learning_rate_schedule: ScheduleType = ScheduleType.CONSTANT
+    weight_decay: float = 0.0

 @attr.s(auto_attribs=True)
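The new `weight_decay` hyperparameter (default `0.0`, so existing configs are unaffected) is passed straight through to `torch.optim.Adam`, which implements weight decay as L2 regularization: each parameter's gradient gets `weight_decay * param` added before the update. A minimal standalone sketch (not ml-agents code) showing the effect:

```python
import torch

def decayed_norm(weight_decay: float) -> float:
    """Run Adam with a zero loss so any movement of the parameters
    comes from weight decay alone; return the final parameter norm."""
    torch.manual_seed(0)
    param = torch.nn.Parameter(torch.ones(4))
    opt = torch.optim.Adam([param], lr=1e-2, weight_decay=weight_decay)
    for _ in range(100):
        opt.zero_grad()
        loss = (param * 0.0).sum()  # zero loss, zero gradient
        loss.backward()
        opt.step()  # with weight_decay > 0, grad becomes weight_decay * param
    return param.norm().item()

# With weight_decay=0 the parameters do not move; with decay they shrink.
print(decayed_norm(0.0), decayed_norm(0.1))
```

Note that Adam's `weight_decay` couples the decay to the adaptive learning-rate scaling; `torch.optim.AdamW` applies it decoupled, but this commit keeps plain `Adam`.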
