|
|
|
|
|
|
import numpy as np |
|
|
|
import threading |
|
|
|
|
|
|
from typing import Dict, List, Mapping, cast, Tuple, Optional, Callable |
|
|
|
from mlagents.torch_utils import torch, nn, default_device |
|
|
|
|
|
|
|
from mlagents_envs.logging_util import get_logger |
|
|
|
|
|
|
self.target_network.network_body.copy_normalization( |
|
|
|
self.policy.actor_critic.network_body |
|
|
|
) |
|
|
|
|
|
|
results = {} |
|
|
|
threads = [] |
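# Dispatch the forward passes for this update to worker threads. Each thread
# stores its output tuple in `results` under its own key and is joined
# before that entry is read.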
|
|
|
policy_thread = self.spawn_forward_thread( |
|
|
|
results, |
|
|
|
"policy", |
|
|
|
self.policy.sample_actions, |
|
|
|
vec_obs, |
|
|
|
vis_obs, |
|
|
|
masks=act_masks, |
|
|
|
|
|
|
) |
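# The sampled actions are needed for the Q-network passes below, so wait
# for the policy thread before continuing.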
|
|
|
policy_thread.join() |
|
|
|
( |
|
|
|
sampled_actions, |
|
|
|
log_probs, |
|
|
|
entropies, |
|
|
|
sampled_values, |
|
|
|
_, |
|
|
|
) = results["policy"] |
|
|
|
# ( |
|
|
|
# sampled_actions, |
|
|
|
# log_probs, |
|
|
|
# entropies, |
|
|
|
# sampled_values, |
|
|
|
# _, |
|
|
|
# ) = self.policy.sample_actions( |
|
|
|
# vec_obs, |
|
|
|
# vis_obs, |
|
|
|
# masks=act_masks, |
|
|
|
# memories=memories, |
|
|
|
# seq_len=self.policy.sequence_length, |
|
|
|
# all_log_probs=not self.policy.use_continuous_act, |
|
|
|
# ) |
|
|
|
q1p_out, q2p_out = self.value_network( |
|
|
|
|
|
|
|
qp_thread = self.spawn_forward_thread( |
|
|
|
results, |
|
|
|
"qp", |
|
|
|
self.value_network, |
|
|
|
vec_obs, |
|
|
|
vis_obs, |
|
|
|
sampled_actions, |
|
|
|
|
|
|
q1_out, q2_out = self.value_network( |
|
|
|
q_thread = self.spawn_forward_thread( |
|
|
|
results, |
|
|
|
"q", |
|
|
|
self.value_network, |
|
|
|
vec_obs, |
|
|
|
vis_obs, |
|
|
|
squeezed_actions, |
|
|
|
|
|
|
|
|
|
# q1_stream, q2_stream = q1_out, q2_out |
|
|
|
else: |
|
|
|
with torch.no_grad(): |
|
|
|
q1p_out, q2p_out = self.value_network( |
|
|
|
|
|
|
q2_stream = self._condense_q_streams(q2_out, actions) |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
target_values, _ = self.target_network( |
|
|
|
target_value_thread = self.spawn_forward_thread( |
|
|
|
results, |
|
|
|
"target_value", |
|
|
|
self.target_network, |
|
|
|
next_vec_obs, |
|
|
|
next_vis_obs, |
|
|
|
memories=next_memories, |
|
|
|
|
|
|
use_discrete = not self.policy.use_continuous_act |
|
|
|
dones = ModelUtils.list_to_tensor(batch["done"]) |
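# Collect the outputs of the remaining forward threads before computing the losses.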
|
|
|
|
|
|
|
q_thread.join() |
|
|
|
q1_stream, q2_stream = results["q"] |
|
|
|
target_value_thread.join() |
|
|
|
target_values, _ = results["target_value"] |
|
|
|
|
|
|
|
qp_thread.join() |
|
|
|
q1p_out, q2p_out = results["qp"] |
|
|
|
|
|
|
|
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete) |
|
|
|
entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete) |
|
|
|
|
|
|
|
|
|
|
ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr) |
|
|
|
self.policy_optimizer.zero_grad() |
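# Run each loss's backward pass on its own worker thread; the optimizers
# step only after all threads are joined below.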
|
|
|
ploss_thread = self.spawn_backward_thread(policy_loss) |
|
|
|
# policy_loss.backward()  # backward now runs on ploss_thread
|
|
|
# self.policy_optimizer.step()  # stepped after the thread joins below
|
|
|
|
|
|
vloss_thread = self.spawn_backward_thread(total_value_loss) |
|
|
|
# total_value_loss.backward()  # backward now runs on vloss_thread
|
|
|
# self.value_optimizer.step()  # stepped after the thread joins below
|
|
|
|
|
|
entloss_thread = self.spawn_backward_thread(entropy_loss) |
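# Wait for every backward pass to finish before stepping the optimizers.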
|
|
|
|
|
|
|
ploss_thread.join() |
|
|
|
vloss_thread.join() |
|
|
|
entloss_thread.join() |
|
|
|
self.policy_optimizer.step() |
|
|
|
self.value_optimizer.step() |
|
|
|
# entropy_loss.backward()  # backward now runs on entloss_thread
self.entropy_optimizer.step()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Update target network |
|
|
|
self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau) |
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
return update_stats |
|
|
|
|
|
|
|
def spawn_forward_thread( |
|
|
|
self, results: Dict[str, Tuple], name: str, func: Callable, *args, **kwargs |
|
|
|
): |
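"""
Start a worker thread that calls func(*args, **kwargs) and stores the
return value in results[name]. The caller must join() the returned
thread before reading results[name].
"""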
|
|
|
thr = TorchSACOptimizer.ForwardThread(results, name, func, args, kwargs) |
|
|
|
thr.start() |
|
|
|
return thr |
|
|
|
|
|
|
|
class ForwardThread(threading.Thread): |
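"""Thread that runs a forward function and publishes its return value in a shared results dict."""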
|
|
|
def __init__( |
|
|
|
self, results: Dict[str, Tuple], name: str, func: Callable, args, kwargs |
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
self.func = func |
|
|
|
self.func_args = args |
|
|
|
self.func_kwargs = kwargs |
|
|
|
self.name = name |
|
|
|
self.results = results |
|
|
|
|
|
|
|
def run(self): |
|
|
|
self.results[self.name] = self.func(*self.func_args, **self.func_kwargs) |
|
|
|
|
|
|
|
def _forward_thread_func( |
|
|
|
self, results: Dict[str, Tuple], name: str, func: Callable, args, kwargs |
|
|
|
): |
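"""Call func(*args, **kwargs) and store the output in results[name]."""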
|
|
|
result = func(*args, **kwargs) |
|
|
|
print(name, result) |
|
|
|
results[name] = result |
|
|
|
|
|
|
|
def spawn_backward_thread(self, loss: torch.Tensor) -> threading.Thread: |
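"""Create a worker thread whose target is loss.backward(); callers join it before stepping the corresponding optimizer."""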
|
|
|
thr = threading.Thread(target=loss.backward) |
|
|
|