import numpy as np
from typing import Dict, List, Mapping, cast, Tuple, Optional

import attr

from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import ActionType
# Twin Q-network wrapper: forward() accepts optional memories and a sequence
# length so the recurrent (LSTM) case is supported.
def forward(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    actions: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    q1_out, _ = self.q1_network(
        vec_inputs,
        vis_inputs,
        actions=actions,
        memories=memories,
        sequence_length=sequence_length,
    )
    q2_out, _ = self.q2_network(
        vec_inputs,
        vis_inputs,
        actions=actions,
        memories=memories,
        sequence_length=sequence_length,
    )
    return q1_out, q2_out
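# Illustrative sketch (not from the original source): a minimal twin-Q module in
# plain PyTorch showing the same pattern as the forward() above, namely two
# independent Q heads evaluated on the same inputs with both estimates returned
# so the caller can take their element-wise minimum. All names and sizes here
# are assumptions for illustration only.
import torch
from torch import nn


class TinyTwinQ(nn.Module):
    def __init__(self, obs_size: int, act_size: int, hidden: int = 32):
        super().__init__()
        self.q1 = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )
        self.q2 = nn.Sequential(
            nn.Linear(obs_size + act_size, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )

    def forward(self, obs: torch.Tensor, act: torch.Tensor):
        x = torch.cat([obs, act], dim=-1)
        return self.q1(x), self.q2(x)


# Example usage: the pessimistic estimate used by SAC is the minimum of the heads.
# q1, q2 = TinyTwinQ(obs_size=8, act_size=2)(torch.randn(4, 8), torch.randn(4, 2))
# min_q = torch.min(q1, q2)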
# Optimizer construction: the value/critic settings are derived from the policy
# network settings, with the recurrent memory halved for the critics.
def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):

        for name in self.stream_names
    }

    # Critics should have 1/2 of the memory of the policy
    critic_memory = policy_network_settings.memory
    if critic_memory is not None:
        critic_memory = attr.evolve(
            critic_memory, memory_size=critic_memory.memory_size // 2
        )
    value_network_settings = attr.evolve(
        policy_network_settings, memory=critic_memory
    )
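# Illustrative sketch (not from the original source): attr.evolve returns a copy
# of an attrs instance with the given fields replaced, which is how the critic
# ends up with half of the policy's memory size above. The settings class below
# is a stand-in for illustration, not the real network/memory settings class.
import attr


@attr.s(auto_attribs=True)
class _MemorySettingsStub:
    memory_size: int = 256
    sequence_length: int = 64


_policy_mem = _MemorySettingsStub()
_critic_mem = attr.evolve(_policy_mem, memory_size=_policy_mem.memory_size // 2)
assert _critic_mem.memory_size == 128 and _policy_mem.memory_size == 256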
    # Both the value network and the target network are built with the
    # reduced-memory settings computed above:
        value_network_settings,

        value_network_settings,
    )

    self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
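# Illustrative sketch (not from the original source): a generic Polyak ("soft")
# update between two modules. With tau=1.0, as in the call above, it degenerates
# to a hard copy that initializes the target network from the current critic.
import torch


def polyak_update(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    with torch.no_grad():
        for src_p, tgt_p in zip(source.parameters(), target.parameters()):
            tgt_p.data.mul_(1.0 - tau)
            tgt_p.data.add_(tau * src_p.data)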
# Soft value target: the minimum of the two Q estimates minus the scaled entropy
# term; the loss is a mean-squared error masked to valid timesteps.
v_backup = min_policy_qs[name] - torch.sum(_ent_coef * log_probs, dim=1)
value_loss = 0.5 * torch.mean(
    loss_masks
    * torch.nn.functional.mse_loss(values[name], v_backup, reduction="none")
)
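# Illustrative sketch (not from the original source): the backup above is the
# soft value target V(s) ~ min_i Q_i(s, a~pi) - alpha * log pi(a|s), and the
# loss is a mean-squared error restricted to unmasked timesteps. The tensors
# below are made-up toy values.
import torch

_min_q = torch.tensor([1.0, 2.0, 3.0])                # min over the two Q heads
_log_probs = torch.tensor([[-0.5], [-0.7], [-0.2]])   # log pi(a|s), one action dim
_alpha = 0.2                                          # entropy coefficient
_mask = torch.tensor([1.0, 1.0, 0.0])                 # last step is padding

_v_backup = _min_q - torch.sum(_alpha * _log_probs, dim=1)
_values = torch.tensor([0.9, 2.1, 0.0])               # current value estimates
_per_step = torch.nn.functional.mse_loss(_values, _v_backup, reduction="none")
_value_loss = 0.5 * torch.mean(_mask * _per_step)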
else:
    # Discrete actions are stored as indices and need an integer tensor.
    actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)
# One saved memory per sequence is recovered from the batch; the target network
# is seeded with the memory recorded one step later so it lines up with the
# next-step observations.
memories_list = [
    ModelUtils.list_to_tensor(batch["memory"][i])
    for i in range(0, len(batch["memory"]), self.policy.sequence_length)
]
# LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
offset = 1 if self.policy.sequence_length > 1 else 0
next_memories_list = [
    ModelUtils.list_to_tensor(
        batch["memory"][i][: self.policy.m_size // 2]
    )  # only pass value part of memory to target network
    for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
]

if len(memories_list) > 0:
    memories = torch.stack(memories_list).unsqueeze(0)
    next_memories = torch.stack(next_memories_list).unsqueeze(0)
else:
    memories = None
    next_memories = None

# Q network memories are 0'ed out, since we don't have them during inference.
q_memories = (
    torch.zeros((memories.shape[0], memories.shape[1], memories.shape[2] // 2))
    if memories is not None
    else None
)
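# Illustrative sketch (not from the original source): with sequences of length L
# stored back to back in a batch, range(start, len(batch), L) picks one saved
# memory per sequence; starting at 1 instead of 0 selects the memory recorded
# one step later, which is what seeds the target network above. Toy numbers only.
_seq_len = 3
_stored_memories = list(range(12))  # stand-in for batch["memory"]
_memory_indices = list(range(0, len(_stored_memories), _seq_len))             # [0, 3, 6, 9]
_offset = 1 if _seq_len > 1 else 0
_next_memory_indices = list(range(_offset, len(_stored_memories), _seq_len))  # [1, 4, 7, 10]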
vis_obs: List[torch.Tensor] = []
next_vis_obs: List[torch.Tensor] = []
if self.policy.use_vis_obs:
# Q estimates are computed for the policy's sampled actions (q1p/q2p) and for
# the actions taken in the batch (q1/q2); in the discrete branch no actions are
# passed, since the networks output one Q value per action. The zeroed
# q_memories and the sequence length are threaded through for the recurrent case.
if self.policy.use_continuous_act:
    squeezed_actions = actions.squeeze(-1)
    q1p_out, q2p_out = self.value_network(
        vec_obs,
        vis_obs,
        sampled_actions,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
    )
    q1_out, q2_out = self.value_network(
        vec_obs,
        vis_obs,
        squeezed_actions,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
    )
else:
    q1p_out, q2p_out = self.value_network(
        vec_obs,
        vis_obs,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
    )
    q1_out, q2_out = self.value_network(
        vec_obs,
        vis_obs,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
    )

# Target values for the next observations come from the target network, seeded
# with the next-step memories.
target_values, _ = self.target_network(
    next_vec_obs,
    next_vis_obs,
    memories=next_memories,
    sequence_length=self.policy.sequence_length,
)
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.int32)
use_discrete = not self.policy.use_continuous_act
dones = ModelUtils.list_to_tensor(batch["done"])
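# Illustrative sketch (not from the original source): in a standard SAC update
# the target values and done flags gathered above combine into the Q backup
#   q_backup = r + gamma * (1 - done) * V_target(s'),
# which the Q losses regress toward. The reward and discount values are made up.
import torch

_rewards = torch.tensor([1.0, 0.0, 0.5])
_dones = torch.tensor([0.0, 0.0, 1.0])
_target_values = torch.tensor([2.0, 1.5, 3.0])
_gamma = 0.99
_q_backup = _rewards + _gamma * (1.0 - _dones) * _target_values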
|