
Thread inference and not backprop

/develop/torch-sac-threading
Ervin Teng, 4 years ago
Current commit: 9b797d61
1 changed file with 94 additions and 21 deletions
ml-agents/mlagents/trainers/sac/optimizer_torch.py (115 changed lines: 94 additions, 21 deletions)

 import numpy as np
 import threading
-from typing import Dict, List, Mapping, cast, Tuple, Optional
+from typing import Dict, List, Mapping, cast, Tuple, Optional, Callable
 from mlagents.torch_utils import torch, nn, default_device
 from mlagents_envs.logging_util import get_logger

         self.target_network.network_body.copy_normalization(
             self.policy.actor_critic.network_body
         )
-        (
-            sampled_actions,
-            log_probs,
-            entropies,
-            sampled_values,
-            _,
-        ) = self.policy.sample_actions(
+        results = {}
+        threads = []
+        policy_thread = self.spawn_forward_thread(
+            results,
+            "policy",
+            self.policy.sample_actions,
             vec_obs,
             vis_obs,
             masks=act_masks,

         )
+        policy_thread.join()
+        (
+            sampled_actions,
+            log_probs,
+            entropies,
+            sampled_values,
+            _,
+        ) = results["policy"]
+        # (
+        #     sampled_actions,
+        #     log_probs,
+        #     entropies,
+        #     sampled_values,
+        #     _,
+        # ) = self.policy.sample_actions(
+        #     vec_obs,
+        #     vis_obs,
+        #     masks=act_masks,
+        #     memories=memories,
+        #     seq_len=self.policy.sequence_length,
+        #     all_log_probs=not self.policy.use_continuous_act,
+        # )
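A threading.Thread discards its target's return value, so the new code routes each forward pass's outputs through the shared results dict under a per-call key, unpacking only after the join. A tiny standalone sketch of that round trip (the tuple is a stand-in for the outputs of sample_actions):

import threading

results = {}

def worker():
    # Stand-in for self.policy.sample_actions(...); a thread's return
    # value is lost, so the output is written into the shared dict instead.
    results["policy"] = ("sampled_actions", "log_probs", "entropies")

policy_thread = threading.Thread(target=worker)
policy_thread.start()
policy_thread.join()
sampled_actions, log_probs, entropies = results["policy"]

The join is what makes the read safe: it guarantees the worker's write has completed and is visible before the unpack.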
-            q1p_out, q2p_out = self.value_network(
+            qp_thread = self.spawn_forward_thread(
+                results,
+                "qp",
+                self.value_network,
                 vec_obs,
                 vis_obs,
                 sampled_actions,

-            q1_out, q2_out = self.value_network(
+            q_thread = self.spawn_forward_thread(
+                results,
+                "q",
+                self.value_network,
                 vec_obs,
                 vis_obs,
                 squeezed_actions,

-            q1_stream, q2_stream = q1_out, q2_out
+            # q1_stream, q2_stream = q1_out, q2_out
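Threading forward passes can genuinely overlap work even under the GIL, because PyTorch's C++ kernels release the GIL while they execute; how much this buys depends on tensor sizes and intra-op thread settings. A rough standalone timing sketch (pinning intra-op parallelism to one thread so the inter-op overlap is visible):

import threading
import time
import torch

torch.set_num_threads(1)  # keep each matmul single-threaded internally
x = torch.randn(1024, 1024)

def forward_like():
    for _ in range(20):
        _ = x @ x  # ATen kernel; the GIL is released while it runs

start = time.time()
workers = [threading.Thread(target=forward_like) for _ in range(2)]
for t in workers:
    t.start()
for t in workers:
    t.join()
print("two threads:", time.time() - start)

start = time.time()
forward_like()
forward_like()
print("sequential: ", time.time() - start)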
         else:
             with torch.no_grad():
                 q1p_out, q2p_out = self.value_network(

             q2_stream = self._condense_q_streams(q2_out, actions)
         with torch.no_grad():
-            target_values, _ = self.target_network(
+            target_value_thread = self.spawn_forward_thread(
+                results,
+                "target_value",
+                self.target_network,
                 next_vec_obs,
                 next_vis_obs,
                 memories=next_memories,
         use_discrete = not self.policy.use_continuous_act
         dones = ModelUtils.list_to_tensor(batch["done"])
+        q_thread.join()
+        q1_stream, q2_stream = results["q"]
+        target_value_thread.join()
+        target_values, _ = results["target_value"]
+        qp_thread.join()
+        q1p_out, q2p_out = results["qp"]
         policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete)
         entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)
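The joins sit after the cheap CPU-side tensor prep (use_discrete, dones), so all three network passes run while that happens; each join then blocks only until its producer finishes, and the unpacking matches what the direct calls used to return. A minimal standalone sketch of the same spawn/overlap/join shape (q_net and target_net are stand-ins, not the repo's networks):

import threading
import time
from typing import Any, Callable, Dict

def spawn_forward_thread(
    results: Dict[str, Any], name: str, func: Callable, *args: Any, **kwargs: Any
) -> threading.Thread:
    # Run func on a worker thread and stash its return value under `name`.
    def _run() -> None:
        results[name] = func(*args, **kwargs)
    thr = threading.Thread(target=_run)
    thr.start()
    return thr

def q_net(obs):
    time.sleep(0.05)  # stand-in for a Q-network forward pass
    return "q1_out", "q2_out"

def target_net(obs):
    time.sleep(0.05)  # stand-in for the target network
    return "target_values", None

results: Dict[str, Any] = {}
q_thread = spawn_forward_thread(results, "q", q_net, "obs")
target_value_thread = spawn_forward_thread(results, "target_value", target_net, "obs")
# ... other CPU-side batch preparation would overlap here ...
q_thread.join()
q1_stream, q2_stream = results["q"]
target_value_thread.join()
target_values, _ = results["target_value"]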

         ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
         self.policy_optimizer.zero_grad()
-        ploss_thread = self.spawn_backward_thread(policy_loss)
+        policy_loss.backward()
+        self.policy_optimizer.step()
+        # ploss_thread = self.spawn_backward_thread(policy_loss)

-        vloss_thread = self.spawn_backward_thread(total_value_loss)
+        total_value_loss.backward()
+        self.value_optimizer.step()
+        # vloss_thread = self.spawn_backward_thread(total_value_loss)

-        entloss_thread = self.spawn_backward_thread(entropy_loss)
-        ploss_thread.join()
-        vloss_thread.join()
-        entloss_thread.join()
-        self.policy_optimizer.step()
-        self.value_optimizer.step()
+        entropy_loss.backward()
+        # entloss_thread = self.spawn_backward_thread(entropy_loss)
+        # ploss_thread.join()
+        # vloss_thread.join()
+        # entloss_thread.join()
+        # self.policy_optimizer.step()
+        # self.value_optimizer.step()
+        # self.entropy_optimizer.step()
         # Update target network
         self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
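As the commit title says, inference is threaded but backprop is not: the three backward passes run sequentially again, with the threaded variants kept as comments. For reference, a minimal sketch of the reverted idea, assuming each loss owns an independent graph (a second backward through a shared graph raises an error unless retain_graph=True) and a PyTorch build whose autograd engine permits concurrent backward calls:

import threading
import torch

def spawn_backward_thread(loss: torch.Tensor) -> threading.Thread:
    # Kick off autograd's backward pass on a worker thread.
    thr = threading.Thread(target=loss.backward)
    thr.start()
    return thr

w1 = torch.randn(3, requires_grad=True)
w2 = torch.randn(3, requires_grad=True)
loss_a = (w1 * 2.0).sum()  # graph touching only w1
loss_b = (w2 * 3.0).sum()  # graph touching only w2

threads = [spawn_backward_thread(loss_a), spawn_backward_thread(loss_b)]
for t in threads:
    t.join()
print(w1.grad, w2.grad)  # both populated once the joins return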

         }
         return update_stats
+    def spawn_forward_thread(
+        self, results: Dict[str, Tuple], name: str, func: Callable, *args, **kwargs
+    ):
+        thr = TorchSACOptimizer.ForwardThread(results, name, func, args, kwargs)
+        thr.start()
+        return thr
+
+    class ForwardThread(threading.Thread):
+        def __init__(
+            self, results: Dict[str, Tuple], name: str, func: Callable, args, kwargs
+        ):
+            super().__init__()
+            self.func = func
+            self.func_args = args
+            self.func_kwargs = kwargs
+            self.name = name
+            self.results = results
+
+        def run(self):
+            self.results[self.name] = self.func(*self.func_args, **self.func_kwargs)
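ForwardThread overrides run() to capture the callable's return value in the shared dict, which a plain Thread would discard. A small usage sketch, assuming the nested class above is in scope (pow is just a stand-in payload); note that assigning self.name in __init__ also renames the underlying thread, since Thread exposes name as a settable property:

results = {}
thr = TorchSACOptimizer.ForwardThread(results, "demo", pow, (2, 10), {})
thr.start()
thr.join()
assert results["demo"] == 1024  # run() stored func(*args, **kwargs)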
+    def _forward_thread_func(
+        self, results: Dict[str, Tuple], name: str, func: Callable, args, kwargs
+    ):
+        result = func(*args, **kwargs)
+        print(name, result)
+        results[name] = result
+
     def spawn_backward_thread(self, loss: torch.Tensor) -> threading.Thread:
         thr = threading.Thread(target=loss.backward)
