
Works with continuous

/develop/sac-targetq
Ervin Teng, 4 years ago
Current commit: 5495b2b6
1 changed file with 73 additions and 20 deletions
ml-agents/mlagents/trainers/sac/optimizer_torch.py (93 changed lines)


      self.act_size,
  )
- self.target_network = ValueNetwork(
+ self.target_network = TorchSACOptimizer.PolicyValueNetwork(
+     self.policy.behavior_spec.action_type,
+     self.act_size,
- self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
+ self.soft_update(self.target_network, self.target_network, 1.0)
  self._log_ent_coef = torch.nn.Parameter(
      torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))),
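For reference, soft_update with tau=1.0 amounts to a hard copy of the source parameters into the freshly built target, while the small tau used later in update() gives the usual Polyak averaging. A minimal sketch of that behaviour, assuming a generic source/target module pair (this helper is illustrative, not the repository's exact soft_update):

import torch

def soft_update(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    # Polyak averaging: target <- tau * source + (1 - tau) * target.
    # tau=1.0 degenerates to a hard copy, as in the initialization above.
    with torch.no_grad():
        for src, tgt in zip(source.parameters(), target.parameters()):
            tgt.data.mul_(1.0 - tau)
            tgt.data.add_(tau * src.data)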

  policy_params = list(self.policy.actor_critic.network_body.parameters()) + list(
      self.policy.actor_critic.distribution.parameters()
  )
- value_params = list(self.value_network.parameters()) + list(
-     self.policy.actor_critic.critic.parameters()
- )
+ value_params = list(self.value_network.parameters())
  logger.debug("value_vars")
  for param in value_params:
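With the policy's critic dropped from value_params, the value optimizer only sees the twin Q networks. A self-contained sketch of how these parameter groups would typically feed separate Adam optimizers (the toy modules and learning rate below are placeholders, not the trainer's configuration):

import torch
from torch import nn

# Stand-ins for the actor body/distribution and the twin Q networks (value_network).
actor_body, actor_dist = nn.Linear(8, 64), nn.Linear(64, 2)
twin_q = nn.ModuleList([nn.Linear(10, 1), nn.Linear(10, 1)])

policy_params = list(actor_body.parameters()) + list(actor_dist.parameters())
value_params = list(twin_q.parameters())  # no longer includes the policy's critic

policy_optimizer = torch.optim.Adam(policy_params, lr=3e-4)
value_optimizer = torch.optim.Adam(value_params, lr=3e-4)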

  self,
  q1_out: Dict[str, torch.Tensor],
  q2_out: Dict[str, torch.Tensor],
- target_values: Dict[str, torch.Tensor],
+ target_q1: Dict[str, torch.Tensor],
+ target_q2: Dict[str, torch.Tensor],
  dones: torch.Tensor,
  rewards: Dict[str, torch.Tensor],
  loss_masks: torch.Tensor,

  q1_stream = q1_out[name].squeeze()
  q2_stream = q2_out[name].squeeze()
  with torch.no_grad():
+     target_q = torch.min(target_q1[name].squeeze(), target_q2[name].squeeze())
-     * target_values[name]
+     * target_q
  )
  _q1_loss = 0.5 * ModelUtils.masked_mean(
      torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks
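The new arguments turn the backup into a clipped double-Q target: the bootstrap value is the element-wise minimum of the two target critics instead of a learned state value. A simplified, self-contained sketch of that backup (the function name and the plain (1 - dones) gate are illustrative; the optimizer also applies a masked MSE per reward stream, as visible above):

import torch

def clipped_double_q_backup(
    rewards: torch.Tensor,    # r_t for one reward stream
    dones: torch.Tensor,      # 1.0 where the episode terminated
    gamma: float,             # discount for this stream
    target_q1: torch.Tensor,  # Q1'(s_{t+1}, a_{t+1})
    target_q2: torch.Tensor,  # Q2'(s_{t+1}, a_{t+1})
) -> torch.Tensor:
    # Take the minimum of the two target critics (clipped double-Q) and
    # bootstrap only where the episode has not terminated.
    with torch.no_grad():
        target_q = torch.min(target_q1, target_q2)
        return rewards + (1.0 - dones) * gamma * target_q

# Example with toy tensors:
y = clipped_double_q_backup(
    rewards=torch.tensor([1.0, 0.0]),
    dones=torch.tensor([0.0, 1.0]),
    gamma=0.99,
    target_q1=torch.tensor([2.0, 3.0]),
    target_q2=torch.tensor([1.5, 4.0]),
)  # -> tensor([2.4850, 0.0000])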

  self.value_network.q2_network.network_body.copy_normalization(
      self.policy.actor_critic.network_body
  )
- self.target_network.network_body.copy_normalization(
+ self.target_network.q1_network.network_body.copy_normalization(
      self.policy.actor_critic.network_body
  )
+ self.target_network.q2_network.network_body.copy_normalization(

      entropies,
-     sampled_values,
      _,
- ) = self.policy.sample_actions(
+ ) = self._sample_actions(
      vec_obs,
      vis_obs,
      masks=act_masks,

  q2_stream = self._condense_q_streams(q2_out, actions)
  with torch.no_grad():
-     target_values, _ = self.target_network(
+     (
+         next_actions,
+         _,
+         _,
+         _,
+     ) = self._sample_actions(
-         memories=next_memories,
+         masks=act_masks,
+         memories=memories,
+         seq_len=self.policy.sequence_length,
+         all_log_probs=not self.policy.use_continuous_act,
      )
+     q1_target, q2_target = self.value_network(
+         next_vec_obs,
+         next_vis_obs,
+         next_actions,
+         memories=q_memories,
          sequence_length=self.policy.sequence_length,
      )
  masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
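Instead of querying a state-value target network on the next observations, the new code resamples next actions from the current policy under torch.no_grad() and evaluates both target critics on the next observation/action pair. A toy, self-contained sketch of that pattern (TwinQ and the tensors below are hypothetical stand-ins, not the repository's PolicyValueNetwork):

import torch
from torch import nn

# Toy twin critic that scores an (observation, action) pair.
class TwinQ(nn.Module):
    def __init__(self, obs_size: int, act_size: int):
        super().__init__()
        self.q1 = nn.Linear(obs_size + act_size, 1)
        self.q2 = nn.Linear(obs_size + act_size, 1)

    def forward(self, obs: torch.Tensor, actions: torch.Tensor):
        x = torch.cat([obs, actions], dim=-1)
        return self.q1(x), self.q2(x)

target_critic = TwinQ(obs_size=8, act_size=2)
next_obs = torch.randn(4, 8)

with torch.no_grad():
    # Stand-in for self._sample_actions(next_vec_obs, next_vis_obs, ...):
    next_actions = torch.tanh(torch.randn(4, 2))
    # Stand-in for the target-network call: Q targets at (s', a').
    q1_target, q2_target = target_critic(next_obs, next_actions)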

  q1_loss, q2_loss = self.sac_q_loss(
-     q1_stream, q2_stream, target_values, dones, rewards, masks
+     q1_stream, q2_stream, q1_target, q2_target, dones, rewards, masks
- value_loss = self.sac_value_loss(
-     log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete
- )
+ # value_loss = self.sac_value_loss(
+ #     log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete
+ # )
- total_value_loss = q1_loss + q2_loss + value_loss
+ total_value_loss = q1_loss + q2_loss
  decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
  ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
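With the state-value loss commented out, the objective driving the value optimizer is just the sum of the two Q losses. A toy sketch of that combined step (the single parameter and optimizer below are stand-ins for the twin Q-network weights and the trainer's value optimizer):

import torch

# A single parameter standing in for the twin Q-network weights.
w = torch.nn.Parameter(torch.zeros(2))
q1_loss = (w[0] - 1.0) ** 2
q2_loss = (w[1] + 1.0) ** 2

value_optimizer = torch.optim.Adam([w], lr=1e-2)
total_value_loss = q1_loss + q2_loss  # no separate state-value term
value_optimizer.zero_grad()
total_value_loss.backward()
value_optimizer.step()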

  self.entropy_optimizer.step()
  # Update target network
- self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
+ self.soft_update(self.value_network, self.target_network, self.tau)
- "Losses/Value Loss": value_loss.item(),
+ "Losses/Value Loss": total_value_loss.item(),
  "Losses/Q1 Loss": q1_loss.item(),
  "Losses/Q2 Loss": q2_loss.item(),
  "Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),

  for reward_provider in self.reward_signals.values():
      modules.update(reward_provider.get_modules())
  return modules
+ def _sample_actions(
+     self,
+     vec_obs: List[torch.Tensor],
+     vis_obs: List[torch.Tensor],
+     masks: Optional[torch.Tensor] = None,
+     memories: Optional[torch.Tensor] = None,
+     seq_len: int = 1,
+     all_log_probs: bool = False,
+ ) -> Tuple[
+     torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
+ ]:
+     """
+     :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
+     """
+     dists, memories = self.policy.actor_critic.get_dists(
+         vec_obs, vis_obs, masks, memories, seq_len
+     )
+     action_list = self.policy.actor_critic.sample_action(dists)
+     log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
+         action_list, dists
+     )
+     actions = torch.stack(action_list, dim=-1)
+     if self.policy.use_continuous_act:
+         actions = actions[:, :, 0]
+     else:
+         actions = actions[:, 0, :]
+     return (
+         actions,
+         all_logs if all_log_probs else log_probs,
+         entropies,
+         memories,
+     )
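The slicing at the end of _sample_actions handles the two layouts produced by torch.stack(action_list, dim=-1): a single continuous distribution yields (batch, act_size, 1), while discrete branches yield (batch, 1, num_branches). A toy illustration of why the continuous and discrete cases index differently (shapes below are assumptions based on the code above):

import torch

batch = 3
# Continuous: one distribution sample of shape (batch, act_size).
continuous = torch.stack([torch.randn(batch, 2)], dim=-1)  # (3, 2, 1)
# Discrete: one (batch, 1) sample per action branch.
discrete = torch.stack(
    [torch.randint(0, 4, (batch, 1)), torch.randint(0, 3, (batch, 1))], dim=-1
)  # (3, 1, 2)

cont_actions = continuous[:, :, 0]  # (3, 2): the full continuous action vector
disc_actions = discrete[:, 0, :]    # (3, 2): one chosen index per branch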