
Manchausen RL

/develop/manch
Ervin Teng, 4 years ago
Commit 2fc23737
1 changed file with 17 additions and 6 deletions

ml-agents/mlagents/trainers/sac/optimizer_torch.py (23 changed lines)


         ModelUtils.soft_update(
             self.policy.actor_critic.critic, self.target_network, 1.0
         )
+        self.alpha = 0.9
         self._log_ent_coef = torch.nn.Parameter(
             torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))),
             requires_grad=True,
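
(Context note, not part of the commit: the title is presumably a spelling of Munchausen RL, Vieillard et al. 2020, which augments the TD target with a scaled log-policy term. The new self.alpha = 0.9 appears to be that scaling coefficient and matches the alpha = 0.9 used in the paper; self._log_ent_coef, already present, is the learnable log of the SAC entropy coefficient, and torch.exp(self._log_ent_coef) in the hunk below recovers it.)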

         target_values: Dict[str, torch.Tensor],
         dones: torch.Tensor,
         rewards: Dict[str, torch.Tensor],
+        log_probs: torch.Tensor,
         loss_masks: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         q1_losses = []

             q1_stream = q1_out[name].squeeze()
             q2_stream = q2_out[name].squeeze()
             with torch.no_grad():
-                q_backup = rewards[name] + (
-                    (1.0 - self.use_dones_in_backup[name] * dones)
-                    * self.gammas[i]
-                    * target_values[name]
+                _ent_coef = torch.exp(self._log_ent_coef)
+                policy_bonus = (
+                    self.alpha
+                    * _ent_coef
+                    * torch.sum((log_probs * log_probs.exp()), dim=1)
+                )
+                q_backup = (
+                    rewards[name]
+                    + policy_bonus
+                    + (
+                        (1.0 - self.use_dones_in_backup[name] * dones)
+                        * self.gammas[i]
+                        * target_values[name]
+                    )
                 )
             _q1_loss = 0.5 * ModelUtils.masked_mean(
                 torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks
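
How the new backup behaves, as a minimal self-contained sketch (all names, shapes, and values below are invented for illustration and are not from the commit): the reward is augmented with a Munchausen-style bonus alpha * ent_coef * sum_a pi(a|s) * log pi(a|s), which is never positive, before the discounted target value is added.

    import torch

    # Invented toy batch: 4 transitions, 3 discrete actions.
    log_probs = torch.log_softmax(torch.randn(4, 3), dim=1)  # log pi(a|s) for every action
    rewards = torch.tensor([1.0, 0.0, 0.5, 1.0])
    target_values = torch.tensor([2.0, 1.5, 0.0, 3.0])       # V_target(s') from the target network
    dones = torch.tensor([0.0, 0.0, 1.0, 0.0])
    gamma, alpha, ent_coef = 0.99, 0.9, 0.2                   # ent_coef stands in for exp(_log_ent_coef)

    with torch.no_grad():
        # Munchausen-style bonus: alpha * tau * sum_a pi(a|s) * log pi(a|s), always <= 0
        policy_bonus = alpha * ent_coef * torch.sum(log_probs * log_probs.exp(), dim=1)
        # Augmented soft-Q backup: r + bonus + gamma * (1 - done) * V_target(s')
        q_backup = rewards + policy_bonus + (1.0 - dones) * gamma * target_values

    print(policy_bonus)  # non-positive penalty added to the reward
    print(q_backup)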

         dones = ModelUtils.list_to_tensor(batch["done"])
         q1_loss, q2_loss = self.sac_q_loss(
-            q1_stream, q2_stream, target_values, dones, rewards, masks
+            q1_stream, q2_stream, target_values, dones, rewards, log_probs, masks
         )
         value_loss = self.sac_value_loss(
             log_probs, value_estimates, q1p_out, q2p_out, masks, use_discrete
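
One point of divergence from the paper, offered as an observation rather than something the commit states: Munchausen-DQN adds alpha * tau * log pi(a_t|s_t) for the action that was actually taken, whereas the bonus above sums pi(a|s) * log pi(a|s) over all discrete actions, i.e. it uses the expected log-probability (the negative policy entropy), assuming log_probs holds the per-action log-probabilities of the full distribution as it does elsewhere in this optimizer.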
