import numpy as np
from typing import Dict, List, Mapping, cast, Tuple, Optional
from mlagents.torch_utils import torch, nn, default_device
import time
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import ActionType
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer


        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
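        # The two Q-network evaluations below use torch.jit._fork / torch.jit._wait,
        # the internal (underscore-prefixed) counterparts of torch.jit.fork / wait:
        #
        #     fut = torch.jit._fork(module, *args)   # schedule the call, get a Future
        #     out = torch.jit._wait(fut)             # block and retrieve the result
        #
        # The calls are only expected to overlap when the callable is scripted or
        # traced; in plain eager Python they may simply run sequentially, so
        # correctness does not depend on the parallelism.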
        q1_fut = torch.jit._fork(
            self.q1_network,
            vec_inputs,
            vis_inputs,
            actions=actions,
            memories=memories,
            sequence_length=sequence_length,
        )
        q2_fut = torch.jit._fork(
            self.q2_network,
            vec_inputs,
            vis_inputs,
            actions=actions,
            memories=memories,
            sequence_length=sequence_length,
        )
        q1_out, _ = torch.jit._wait(q1_fut)
        q2_out, _ = torch.jit._wait(q2_fut)
        return q1_out, q2_out

    def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):

            policy_network_settings,
            self.policy.behavior_spec.action_type,
            self.act_size,
        )
        (
            dummy_vec_obs,
            dummy_vis_obs,
            dummy_masks,
            dummy_memories,
        ) = ModelUtils.create_dummy_input(self.policy)
        # Representative inputs used to trace the Q-network pair into TorchScript below.
        example_inputs = (
            dummy_vec_obs,
            dummy_vis_obs,
            torch.zeros((1, sum(self.act_size)))
            if self.policy.use_continuous_act
            else None,
            dummy_memories,
            torch.tensor(self.policy.sequence_length),
        )
        self.value_network = torch.jit.trace(
            self.value_network, example_inputs, strict=False
        )
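        # Note: torch.jit.trace records the operations run on example_inputs and
        # returns a ScriptModule; strict=False tells the tracer to tolerate mutable
        # container values (the Dict[str, torch.Tensor] outputs of the Q networks)
        # instead of erroring. The traced module is then called with positional
        # arguments in the same order as example_inputs.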

        self.target_network = ValueNetwork(

        indexed by name. If none, don't update the reward signals.
        :return: Output from update process.
        """
        t0 = time.time()
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

        next_memories = None
        # Q network memories are zeroed out, since we don't have them during inference.
        # The traced value network expects a tensor here, so an empty (1, 1, 0) tensor
        # stands in for the missing memories rather than None.
        q_memories = (
            torch.zeros_like(next_memories)
            if next_memories is not None
            else torch.empty((1, 1, 0))
        )

        vis_obs: List[torch.Tensor] = []

            next_vis_obs.append(next_vis_ob)

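        # The (traced) Q networks keep their own running observation-normalization
        # statistics; copying them from the policy's network body keeps both sides
        # normalizing observations identically during this update.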
        # Copy normalizers from policy
        self.value_network.q1_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        self.value_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        t1 = time.time()
        (
            sampled_actions,
            log_probs,

            seq_len=self.policy.sequence_length,
            all_log_probs=not self.policy.use_continuous_act,
        )
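        # sampled_actions / log_probs come from the current policy, not from the replay
        # buffer; they feed the Q(s, a~pi) estimates and the entropy term of the SAC
        # update. For discrete control, all_log_probs=True asks for log-probabilities of
        # every action rather than only the sampled one.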
        t2 = time.time()
        if self.policy.use_continuous_act:
            # Evaluate both Q networks on the policy's sampled actions and on the
            # replay-buffer actions; the two traced calls are forked so they can run
            # concurrently and are joined below.
            qp_fut = torch.jit._fork(
                self.value_network,
                vec_obs,
                torch.tensor(vis_obs),
                sampled_actions,
                q_memories,
                torch.tensor(self.policy.sequence_length),
            )
            q_fut = torch.jit._fork(
                self.value_network,
                vec_obs,
                torch.tensor(vis_obs),
                squeezed_actions,
                q_memories,
                torch.tensor(self.policy.sequence_length),
            )
            q1p_out, q2p_out = torch.jit._wait(qp_fut)
            q1_out, q2_out = torch.jit._wait(q_fut)
            q1_stream, q2_stream = q1_out, q2_out
        else:
            with torch.no_grad():

            )
            q1_stream = self._condense_q_streams(q1_out, actions)
            q2_stream = self._condense_q_streams(q2_out, actions)
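            # With discrete (branched) actions the Q networks output a value per
            # possible action, so _condense_q_streams reduces each stream to the Q
            # value of the actions that were actually taken before the loss is computed.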

        t3 = time.time()
        with torch.no_grad():
            target_values, _ = self.target_network(
                next_vec_obs,

        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
        use_discrete = not self.policy.use_continuous_act
        dones = ModelUtils.list_to_tensor(batch["done"])

        t4 = time.time()
        q1_loss, q2_loss = self.sac_q_loss(
            q1_stream, q2_stream, target_values, dones, rewards, masks
        )
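        # Rough shape of the Q loss (the exact handling lives in sac_q_loss): each Q
        # stream is regressed toward the bootstrapped SAC target
        #   y = r + gamma * (1 - done) * V_target(s')
        # per reward signal, with masked-out timesteps excluded from the mean.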

        entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)

        total_value_loss = q1_loss + q2_loss + value_loss

        t5 = time.time()

        t6 = time.time()

        t7 = time.time()

        t8 = time.time()
        # Update target network
        self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
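        # soft_update applies Polyak averaging parameter-wise,
        #   theta_target <- tau * theta_critic + (1 - tau) * theta_target,
        # so the target value network trails the current critic smoothly instead of
        # being copied wholesale.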
        update_stats = {

            "Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
            "Policy/Learning Rate": decay_lr,
        }
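        # The entropy coefficient is optimized as a log value (self._log_ent_coef),
        # so it is exponentiated here only for reporting.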

        t9 = time.time()
        # Timing instrumentation: prints the duration of each update phase, most
        # recent interval (t9 - t8) first, oldest (t1 - t0) last.
        print(
            t9 - t8,
            t8 - t7,
            t7 - t6,
            t6 - t5,
            t5 - t4,
            t4 - t3,
            t3 - t2,
            t2 - t1,
            t1 - t0,
        )
        return update_stats

    def update_reward_signals(