
Discrete and entropy coeff

/develop/sac-targetq
Ervin Teng, 4 years ago
Current commit
52efe509
1 changed file with 41 additions and 101 deletions
ml-agents/mlagents/trainers/sac/optimizer_torch.py (142 changed lines)

from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.timers import timed
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.settings import TrainerSettings, SACSettings
EPSILON = 1e-6 # Small value to avoid divide by zero

        q2_out: Dict[str, torch.Tensor],
        target_q1: Dict[str, torch.Tensor],
        target_q2: Dict[str, torch.Tensor],
        next_log_prob: torch.Tensor,
        discrete: bool = False,
        with torch.no_grad():
            _ent_coef = torch.exp(self._log_ent_coef)
            if not discrete:
                ent_bonus = -torch.sum(_ent_coef * next_log_prob, dim=1)
            else:
                branched_per_action_ent = ModelUtils.break_into_branches(
                    next_log_prob * next_log_prob.exp(), self.act_size
                )
                # We have to do entropy bonus per action branch
                branched_ent_bonus = torch.stack(
                    [
                        torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
                        for i, _lp in enumerate(branched_per_action_ent)
                    ]
                )
                ent_bonus = -torch.mean(branched_ent_bonus, axis=0).squeeze()
            target_q = torch.min(
                target_q1[name].squeeze(), target_q2[name].squeeze()
            )
            target_q = target_q + ent_bonus
            q_backup = rewards[name] + (
                (1.0 - self.use_dones_in_backup[name] * dones)
                * self.gammas[i]
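For reference, the hunk above builds the entropy-regularized soft Q backup: roughly q_backup = r + gamma * (1 - done) * (min(Q1_target, Q2_target) + ent_bonus), where ent_bonus is -alpha * log pi(a'|s') for continuous actions and the branch-averaged entropy weighted by a per-branch alpha_i for discrete actions. The standalone sketch below reproduces the discrete entropy-bonus arithmetic on dummy tensors; the shapes, branch sizes, and variable names are illustrative and not taken from the diff.

import torch

# Toy setup: batch of 4, two discrete branches of sizes 3 and 2 (illustrative only).
batch, branch_sizes = 4, [3, 2]
log_ent_coef = torch.zeros(len(branch_sizes))  # one log-alpha per branch, as the snippet indexes _ent_coef[i]
next_log_probs = [torch.log_softmax(torch.randn(batch, n), dim=1) for n in branch_sizes]

with torch.no_grad():
    ent_coef = log_ent_coef.exp()
    # Per-branch bonus: alpha_i * H(pi_i) = -alpha_i * sum_a pi_i(a) * log pi_i(a)
    branch_bonus = torch.stack(
        [-ent_coef[i] * (lp.exp() * lp).sum(dim=1) for i, lp in enumerate(next_log_probs)]
    )
    ent_bonus = branch_bonus.mean(dim=0)  # average over branches, shape [batch]
    target_q = torch.min(torch.randn(batch), torch.randn(batch))  # stand-in for min of target Q heads
    rewards, dones, gamma = torch.randn(batch), torch.zeros(batch), 0.99
    q_backup = rewards + (1.0 - dones) * gamma * (target_q + ent_bonus)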

                out=target_param.data,
            )
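The two lines above are the tail of the target-network parameter update; the enclosing call sits outside this hunk. As background only, a generic Polyak-style soft update written with the same out=target_param.data idiom might look like the sketch below. The function name, tau, and the surrounding structure are assumptions, not this file's code.

import torch

def soft_update(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    # Move each target parameter a fraction tau toward the corresponding source parameter, in place.
    with torch.no_grad():
        for source_param, target_param in zip(source.parameters(), target.parameters()):
            torch.add(
                (1.0 - tau) * target_param.data,
                tau * source_param.data,
                out=target_param.data,
            )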
    def sac_value_loss(
        self,
        log_probs: torch.Tensor,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
        discrete: bool,
    ) -> torch.Tensor:
        min_policy_qs = {}
        with torch.no_grad():
            _ent_coef = torch.exp(self._log_ent_coef)
            for name in values.keys():
                if not discrete:
                    min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
                else:
                    action_probs = log_probs.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * action_probs, self.act_size
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * action_probs, self.act_size
                    )
                    _q1p_mean = torch.mean(
                        torch.stack(
                            [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p]
                        ),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack(
                            [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p]
                        ),
                        dim=0,
                    )
                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)
        value_losses = []
        if not discrete:
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _ent_coef * log_probs, dim=1
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
                )
                value_losses.append(value_loss)
        else:
            branched_per_action_ent = ModelUtils.break_into_branches(
                log_probs * log_probs.exp(), self.act_size
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, axis=0
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf found")
        return value_loss
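In the discrete branch of sac_value_loss above, the expected Q under the current policy is formed per action branch: weight each branch's Q values by the action probabilities, sum within the branch, then average across branches. The short sketch below walks through that reduction on dummy tensors, assuming break_into_branches simply splits the flat action axis by branch sizes (the helper's body is not part of this diff).

import torch

def split_branches(flat, branch_sizes):
    # Assumed behaviour of ModelUtils.break_into_branches: split dim 1 into one tensor per branch.
    return list(torch.split(flat, branch_sizes, dim=1))

batch, branch_sizes = 4, [3, 2]
q1p = torch.randn(batch, sum(branch_sizes))  # one Q value per action in every branch
log_probs = torch.cat(
    [torch.log_softmax(torch.randn(batch, n), dim=1) for n in branch_sizes], dim=1
)
action_probs = log_probs.exp()

branched = split_branches(q1p * action_probs, branch_sizes)
# E_{a~pi}[Q] per branch, then averaged over branches: one expected Q per sample.
q1p_mean = torch.mean(
    torch.stack([b.sum(dim=1, keepdim=True) for b in branched]), dim=0
)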
    def sac_policy_loss(
        self,
        log_probs: torch.Tensor,

        self.target_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        (sampled_actions, log_probs, _, _) = self._sample_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,

            q2_stream = self._condense_q_streams(q2_out, actions)
        with torch.no_grad():
            (next_actions, next_log_probs, _, _) = self._sample_actions(  # was: (next_actions, _, _, _)
                next_vec_obs,
                next_vis_obs,
                masks=act_masks,

            )
            q1_target, q2_target = self.target_network(  # was: self.value_network(
                next_actions if self.policy.use_continuous_act else None,  # was: next_actions,
            if not self.policy.use_continuous_act:
                q1_target = self._condense_q_streams(q1_target, next_actions)
                q2_target = self._condense_q_streams(q2_target, next_actions)
            # was: q1_stream, q2_stream, q1_target, q2_target, dones, rewards, masks
            q1_stream,
            q2_stream,
            q1_target,
            q2_target,
            next_log_probs,
            dones,
            rewards,
            masks,
            discrete=not self.policy.use_continuous_act,
        )
        # value_loss = self.sac_value_loss(
        #     log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete
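The _condense_q_streams calls above reduce branch-wise Q outputs to the Q value of the discrete action actually taken; its body is not shown in this diff. The sketch below captures the presumed idea (select each branch's taken-action Q with a one-hot mask, then average over branches); the function name and shapes are illustrative only.

import torch
import torch.nn.functional as F

def condense_q(q_flat: torch.Tensor, actions: torch.Tensor, branch_sizes) -> torch.Tensor:
    # q_flat: [batch, sum(branch_sizes)]; actions: [batch, num_branches] integer indices.
    branched_q = torch.split(q_flat, branch_sizes, dim=1)
    taken_qs = []
    for i, (q_b, n) in enumerate(zip(branched_q, branch_sizes)):
        onehot = F.one_hot(actions[:, i], num_classes=n).float()
        taken_qs.append((q_b * onehot).sum(dim=1, keepdim=True))
    return torch.stack(taken_qs).mean(dim=0)  # [batch, 1]

q_values = condense_q(torch.randn(4, 5), torch.tensor([[0, 1], [2, 0], [1, 1], [0, 0]]), [3, 2])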

        memories: Optional[torch.Tensor] = None,
        seq_len: int = 1,
        all_log_probs: bool = False,
        # was: -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
        """

        else:
            actions = actions[:, 0, :]
        return (actions, all_logs if all_log_probs else log_probs, entropies, memories)