|
|
|
|
|
|
import numpy as np |
|
|
|
from typing import Dict, List, Mapping, cast, Tuple |
|
|
|
from typing import Dict, List, Mapping, cast, Tuple, Optional |
|
|
|
import attr |
|
|
|
|
|
|
|
from mlagents_envs.logging_util import get_logger |
|
|
|
from mlagents_envs.base_env import ActionType |
|
|
|
|
|
|
self, |
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
actions: torch.Tensor = None, |
|
|
|
actions: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
q1_out, _ = self.q1_network(vec_inputs, vis_inputs, actions=actions) |
|
|
|
q2_out, _ = self.q2_network(vec_inputs, vis_inputs, actions=actions) |
|
|
|
q1_out, _ = self.q1_network( |
|
|
|
vec_inputs, |
|
|
|
vis_inputs, |
|
|
|
actions=actions, |
|
|
|
memories=memories, |
|
|
|
sequence_length=sequence_length, |
|
|
|
) |
|
|
|
q2_out, _ = self.q2_network( |
|
|
|
vec_inputs, |
|
|
|
vis_inputs, |
|
|
|
actions=actions, |
|
|
|
memories=memories, |
|
|
|
sequence_length=sequence_length, |
|
|
|
) |
|
|
|
return q1_out, q2_out |
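
# Illustrative sketch, not from the original source: SAC typically reduces the two Q
# estimates returned above to their element-wise minimum per reward stream to curb
# overestimation. This assumes q1_out and q2_out are mappings keyed by reward-stream
# name; all names below are stand-ins.
import torch

def example_min_policy_qs(q1_out, q2_out):
    return {name: torch.min(q1_out[name], q2_out[name]) for name in q1_out}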
|
|
|
|
|
|
|
def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): |
|
|
|
|
|
|
for name in self.stream_names |
|
|
|
} |
|
|
|
|
|
|
|
# Critics should have half the memory size of the policy
|
|
|
critic_memory = policy_network_settings.memory |
|
|
|
if critic_memory is not None: |
|
|
|
critic_memory = attr.evolve( |
|
|
|
critic_memory, memory_size=critic_memory.memory_size // 2 |
|
|
|
) |
|
|
|
value_network_settings = attr.evolve( |
|
|
|
policy_network_settings, memory=critic_memory |
|
|
|
) |
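
# Illustrative sketch, not from the original source: attr.evolve returns a copy of an
# attrs instance with the named fields replaced, so halving the critic memory above
# never mutates the policy's own settings. ExampleSettings is a made-up class.
@attr.s(auto_attribs=True)
class ExampleSettings:
    memory_size: int = 128

_halved = attr.evolve(ExampleSettings(), memory_size=ExampleSettings().memory_size // 2)
assert _halved.memory_size == 64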
|
|
|
|
|
|
|
policy_network_settings, |
|
|
|
value_network_settings, |
|
|
|
|
|
|
|
policy_network_settings, |
|
|
|
value_network_settings, |
|
|
|
) |
|
|
|
self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0) |
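
# Illustrative sketch, not from the original source: what a soft (Polyak) update
# generally looks like. With tau=1.0, as in the call above, the target critic is
# hard-copied from the current critic. The function name and signature here are
# stand-ins, not the ML-Agents API.
def example_soft_update(source_net, target_net, tau):
    for src, tgt in zip(source_net.parameters(), target_net.parameters()):
        tgt.data.copy_(tau * src.data + (1.0 - tau) * tgt.data)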
|
|
|
|
|
|
|
|
|
|
* self.gammas[i] |
|
|
|
* target_values[name] |
|
|
|
) |
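
# Illustrative sketch, not from the original source: the bootstrapped target that the
# discounted term above typically feeds into, written for a single reward stream.
# rewards, dones, gamma and target_value are stand-in tensors.
def example_q_backup(rewards, dones, gamma, target_value):
    # Reward plus discounted target value, zeroed on terminal steps.
    return rewards + (1.0 - dones) * gamma * target_value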
|
|
|
_q1_loss = 0.5 * torch.mean( |
|
|
|
loss_masks * torch.nn.functional.mse_loss(q_backup, q1_stream) |
|
|
|
_q1_loss = 0.5 * ModelUtils.masked_mean( |
|
|
|
torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks |
|
|
|
_q2_loss = 0.5 * torch.mean( |
|
|
|
loss_masks * torch.nn.functional.mse_loss(q_backup, q2_stream) |
|
|
|
_q2_loss = 0.5 * ModelUtils.masked_mean( |
|
|
|
torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks |
|
|
|
) |
|
|
|
|
|
|
|
q1_losses.append(_q1_loss) |
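
# Illustrative sketch, not from the original source: one plausible masked mean,
# assuming masks marks the valid (non-padding) timesteps of each sequence. The actual
# ModelUtils.masked_mean may differ in detail.
def example_masked_mean(tensor, masks):
    masks = masks.float()
    return (tensor * masks).sum() / masks.sum().clamp(min=1.0)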
|
|
|
|
|
|
v_backup = min_policy_qs[name] - torch.sum( |
|
|
|
_ent_coef * log_probs, dim=1 |
|
|
|
) |
|
|
|
|
|
|
value_loss = 0.5 * torch.mean( |
|
|
|
loss_masks * torch.nn.functional.mse_loss(values[name], v_backup) |
|
|
|
value_loss = 0.5 * ModelUtils.masked_mean( |
|
|
|
torch.nn.functional.mse_loss(values[name], v_backup), loss_masks |
|
|
|
) |
|
|
|
value_losses.append(value_loss) |
|
|
|
else: |
|
|
|
|
|
|
v_backup = min_policy_qs[name] - torch.mean( |
|
|
|
branched_ent_bonus, axis=0 |
|
|
|
) |
|
|
|
value_loss = 0.5 * torch.mean( |
|
|
|
loss_masks |
|
|
|
* torch.nn.functional.mse_loss(values[name], v_backup.squeeze()) |
|
|
|
value_loss = 0.5 * ModelUtils.masked_mean( |
|
|
|
torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
|
|
|
loss_masks, |
|
|
|
) |
|
|
|
value_losses.append(value_loss) |
|
|
|
value_loss = torch.mean(torch.stack(value_losses)) |
|
|
|
|
|
|
if not discrete: |
|
|
|
mean_q1 = mean_q1.unsqueeze(1) |
|
|
|
batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1) |
|
|
|
policy_loss = torch.mean(loss_masks * batch_policy_loss) |
|
|
|
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks) |
|
|
|
else: |
|
|
|
action_probs = log_probs.exp() |
|
|
|
branched_per_action_ent = ModelUtils.break_into_branches( |
|
|
|
|
|
|
target_current_diff = torch.squeeze( |
|
|
|
target_current_diff_branched, axis=2 |
|
|
|
) |
|
|
|
entropy_loss = -torch.mean( |
|
|
|
loss_masks |
|
|
|
* torch.mean(self._log_ent_coef * target_current_diff, axis=1) |
|
|
|
entropy_loss = -1 * ModelUtils.masked_mean( |
|
|
|
torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks |
|
|
|
) |
|
|
|
|
|
|
|
return entropy_loss |
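
# Illustrative sketch, not from the original source: splitting a concatenated
# discrete-action tensor into per-branch slices, similar in spirit to
# ModelUtils.break_into_branches used above. branch_sizes lists the number of
# actions in each discrete branch.
def example_break_into_branches(concatenated, branch_sizes):
    offsets = [0]
    for size in branch_sizes:
        offsets.append(offsets[-1] + size)
    return [
        concatenated[:, offsets[i] : offsets[i + 1]]
        for i in range(len(branch_sizes))
    ]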
|
|
|
|
|
|
else: |
|
|
|
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long) |
|
|
|
|
|
|
|
memories = [ |
|
|
|
memories_list = [ |
|
|
|
if len(memories) > 0: |
|
|
|
memories = torch.stack(memories).unsqueeze(0) |
|
|
|
# LSTM shouldn't have a sequence length < 1, but guard against indexing out of range if it does.
|
|
|
offset = 1 if self.policy.sequence_length > 1 else 0 |
|
|
|
next_memories_list = [ |
|
|
|
ModelUtils.list_to_tensor( |
|
|
|
batch["memory"][i][self.policy.m_size // 2 :] |
|
|
|
) # only pass value part of memory to target network |
|
|
|
for i in range(offset, len(batch["memory"]), self.policy.sequence_length) |
|
|
|
] |
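
# Illustrative sketch, not from the original source: the sampling pattern used above.
# One initial memory is taken per sequence (every sequence_length-th step, optionally
# offset by one for the "next" memories), and only the value half of each stored
# memory vector is kept for the target network. Names are stand-ins.
def example_initial_value_memories(batch_memory, m_size, sequence_length, offset):
    return [
        batch_memory[i][m_size // 2 :]
        for i in range(offset, len(batch_memory), sequence_length)
    ]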
|
|
|
|
|
|
|
if len(memories_list) > 0: |
|
|
|
memories = torch.stack(memories_list).unsqueeze(0) |
|
|
|
next_memories = torch.stack(next_memories_list).unsqueeze(0) |
|
|
|
else: |
|
|
|
memories = None |
|
|
|
next_memories = None |
|
|
|
# Q network memories are zeroed out, since we don't have them during inference.
|
|
|
q_memories = torch.zeros_like(next_memories) if next_memories is not None else None
|
|
|
|
|
|
|
vis_obs: List[torch.Tensor] = [] |
|
|
|
next_vis_obs: List[torch.Tensor] = [] |
|
|
|
if self.policy.use_vis_obs: |
|
|
|
|
|
|
) |
|
|
|
if self.policy.use_continuous_act: |
|
|
|
squeezed_actions = actions.squeeze(-1) |
|
|
|
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs, sampled_actions) |
|
|
|
q1_out, q2_out = self.value_network(vec_obs, vis_obs, squeezed_actions) |
|
|
|
q1p_out, q2p_out = self.value_network( |
|
|
|
vec_obs, |
|
|
|
vis_obs, |
|
|
|
sampled_actions, |
|
|
|
memories=q_memories, |
|
|
|
sequence_length=self.policy.sequence_length, |
|
|
|
) |
|
|
|
q1_out, q2_out = self.value_network( |
|
|
|
vec_obs, |
|
|
|
vis_obs, |
|
|
|
squeezed_actions, |
|
|
|
memories=q_memories, |
|
|
|
sequence_length=self.policy.sequence_length, |
|
|
|
) |
|
|
|
q1p_out, q2p_out = self.value_network(vec_obs, vis_obs) |
|
|
|
q1_out, q2_out = self.value_network(vec_obs, vis_obs) |
|
|
|
q1p_out, q2p_out = self.value_network( |
|
|
|
vec_obs, |
|
|
|
vis_obs, |
|
|
|
memories=q_memories, |
|
|
|
sequence_length=self.policy.sequence_length, |
|
|
|
) |
|
|
|
q1_out, q2_out = self.value_network( |
|
|
|
vec_obs, |
|
|
|
vis_obs, |
|
|
|
memories=q_memories, |
|
|
|
sequence_length=self.policy.sequence_length, |
|
|
|
) |
|
|
|
target_values, _ = self.target_network(next_vec_obs, next_vis_obs) |
|
|
|
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32) |
|
|
|
target_values, _ = self.target_network( |
|
|
|
next_vec_obs, |
|
|
|
next_vis_obs, |
|
|
|
memories=next_memories, |
|
|
|
sequence_length=self.policy.sequence_length, |
|
|
|
) |
|
|
|
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool) |
|
|
|
use_discrete = not self.policy.use_continuous_act |
|
|
|
dones = ModelUtils.list_to_tensor(batch["done"]) |
|
|
|
|
|
|
|