
Try futures in Optimizer

/develop/jit/experiments
Ervin Teng, 4 years ago
Current commit: a305a41b
1 file changed, with 81 additions and 25 deletions
  1. ml-agents/mlagents/trainers/sac/optimizer_torch.py (106 changed lines)

ml-agents/mlagents/trainers/sac/optimizer_torch.py


 import numpy as np
 from typing import Dict, List, Mapping, cast, Tuple, Optional
 from mlagents.torch_utils import torch, nn, default_device
+import time
 from mlagents_envs.logging_util import get_logger
 from mlagents_envs.base_env import ActionType
 from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer

         memories: Optional[torch.Tensor] = None,
         sequence_length: int = 1,
     ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
-        q1_out, _ = self.q1_network(
+        q1_fut = torch.jit._fork(
+            self.q1_network,
             vec_inputs,
             vis_inputs,
             actions=actions,

-        q2_out, _ = self.q2_network(
+        q2_fut = torch.jit._fork(
+            self.q2_network,
             vec_inputs,
             vis_inputs,
             actions=actions,

+        q1_out, _ = torch.jit._wait(q1_fut)
+        q2_out, _ = torch.jit._wait(q2_fut)
         return q1_out, q2_out
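The core idea of the hunk above: instead of calling q1_network and q2_network back to back, each call is submitted with torch.jit._fork (the internal spelling of torch.jit.fork), which returns a future, and the results are collected with torch.jit._wait. A minimal, self-contained sketch of that pattern with a toy stand-in module (TinyQNet is hypothetical, not an ml-agents class); the callables are scripted so the forked calls can run on PyTorch's inter-op thread pool:

# Sketch: run two independent Q-network forward passes concurrently
# and join their results with fork/wait futures.
import torch
from torch import nn

class TinyQNet(nn.Module):  # hypothetical stand-in for the ml-agents Q networks
    def __init__(self, obs_size: int, act_size: int):
        super().__init__()
        self.layer = nn.Linear(obs_size + act_size, 1)

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        return self.layer(torch.cat([obs, actions], dim=-1))

q1_network = torch.jit.script(TinyQNet(8, 2))
q2_network = torch.jit.script(TinyQNet(8, 2))
obs, actions = torch.randn(32, 8), torch.randn(32, 2)

# torch.jit.fork schedules the call asynchronously and returns a Future;
# torch.jit.wait blocks until the corresponding result is ready.
q1_fut = torch.jit.fork(q1_network, obs, actions)
q2_fut = torch.jit.fork(q2_network, obs, actions)
q1_out = torch.jit.wait(q1_fut)
q2_out = torch.jit.wait(q2_fut)
print(q1_out.shape, q2_out.shape)  # torch.Size([32, 1]) twice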
     def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):

             policy_network_settings,
             self.policy.behavior_spec.action_type,
             self.act_size,
         )
+        (
+            dummy_vec_obs,
+            dummy_vis_obs,
+            dummy_masks,
+            dummy_memories,
+        ) = ModelUtils.create_dummy_input(self.policy)
+        example_inputs = (
+            dummy_vec_obs,
+            dummy_vis_obs,
+            torch.zeros((1, sum(self.act_size)))
+            if self.policy.use_continuous_act
+            else None,
+            dummy_memories,
+            torch.tensor(self.policy.sequence_length),
+        )
+        self.value_network = torch.jit.trace(
+            self.value_network, example_inputs, strict=False
+        )
         self.target_network = ValueNetwork(
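Above, the value network is wrapped with torch.jit.trace, which records the operations performed on the given example_inputs and returns a ScriptModule; strict=False permits mutable container outputs such as the dicts of Q-values returned here. A small sketch under those assumptions, using a toy module rather than the ml-agents ValueNetwork:

# Sketch: trace a module that returns a dict of tensors, mirroring the
# strict=False trace of the value network above.  ToyValueNet is made up.
import torch
from torch import nn

class ToyValueNet(nn.Module):
    def forward(self, obs: torch.Tensor):
        # Dict output is why strict=False matters: strict tracing rejects
        # mutable container outputs by default.
        return {"extrinsic": obs.sum(dim=-1, keepdim=True)}

net = ToyValueNet()
example_inputs = (torch.zeros(1, 8),)
traced = torch.jit.trace(net, example_inputs, strict=False)
print(traced(torch.randn(4, 8))["extrinsic"].shape)  # torch.Size([4, 1])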

         indexed by name. If none, don't update the reward signals.
         :return: Output from update process.
         """
+        t0 = time.time()
         rewards = {}
         for name in self.reward_signals:
             rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

             next_memories = None
         # Q network memories are 0'ed out, since we don't have them during inference.
         q_memories = (
-            torch.zeros_like(next_memories) if next_memories is not None else None
+            torch.zeros_like(next_memories)
+            if next_memories is not None
+            else torch.empty((1, 1, 0))
         )
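The else branch now substitutes torch.empty((1, 1, 0)) for the previous None, presumably so the traced value network always receives a Tensor-typed memories argument; an empty placeholder can still be told apart from real recurrent memories by its element count. A hedged sketch of that convention (has_memories is a made-up helper, not ml-agents API):

# Sketch of the "empty tensor instead of None" convention.
import torch

def has_memories(memories: torch.Tensor) -> bool:
    # A (1, 1, 0) placeholder has zero elements, so it can stand in for
    # "no recurrent memories" while still being a Tensor-typed argument.
    return memories.numel() > 0

placeholder = torch.empty((1, 1, 0))
real_memories = torch.zeros((1, 1, 128))
print(has_memories(placeholder), has_memories(real_memories))  # False True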
         vis_obs: List[torch.Tensor] = []

             next_vis_obs.append(next_vis_ob)
         # Copy normalizers from policy
-        self.value_network.q1_network.network_body.copy_normalization(
-            self.policy.actor_critic.network_body
-        )
-        self.value_network.q2_network.network_body.copy_normalization(
-            self.policy.actor_critic.network_body
-        )
+        # self.value_network.q1_network.network_body.copy_normalization(
+        #     self.policy.actor_critic.network_body
+        # )
+        # self.value_network.q2_network.network_body.copy_normalization(
+        #     self.policy.actor_critic.network_body
+        # )
+        t1 = time.time()
         (
             sampled_actions,
             log_probs,

             seq_len=self.policy.sequence_length,
             all_log_probs=not self.policy.use_continuous_act,
         )
+        t2 = time.time()
-            q1p_out, q2p_out = self.value_network(
-                vec_obs,
-                vis_obs,
-                sampled_actions,
-                memories=q_memories,
-                sequence_length=self.policy.sequence_length,
-            )
+            # q1p_out, q2p_out = self.value_network(
+            #     vec_obs,
+            #     torch.tensor(vis_obs),
+            #     sampled_actions,
+            #     q_memories,
+            #     torch.tensor(self.policy.sequence_length),
+            # )
+            qp_fut = torch.jit._fork(
+                self.value_network,
+                vec_obs,
+                torch.tensor(vis_obs),
+                sampled_actions,
+                q_memories,
+                torch.tensor(self.policy.sequence_length),
+            )
-            q1_out, q2_out = self.value_network(
-                vec_obs,
-                vis_obs,
-                squeezed_actions,
-                memories=q_memories,
-                sequence_length=self.policy.sequence_length,
-            )
+            q_fut = torch.jit._fork(
+                self.value_network,
+                vec_obs,
+                torch.tensor(vis_obs),
+                squeezed_actions,
+                q_memories,
+                torch.tensor(self.policy.sequence_length),
+            )
+            # q1_out, q2_out = self.value_network(
+            #     vec_obs,
+            #     torch.tensor(vis_obs),
+            #     squeezed_actions,
+            #     q_memories,
+            #     torch.tensor(self.policy.sequence_length),
+            # )
+            q1p_out, q2p_out = torch.jit._wait(qp_fut)
+            q1_out, q2_out = torch.jit._wait(q_fut)
             q1_stream, q2_stream = q1_out, q2_out
         else:
             with torch.no_grad():

                 )
             q1_stream = self._condense_q_streams(q1_out, actions)
             q2_stream = self._condense_q_streams(q2_out, actions)
+        t3 = time.time()
         with torch.no_grad():
             target_values, _ = self.target_network(
                 next_vec_obs,

         masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
         use_discrete = not self.policy.use_continuous_act
         dones = ModelUtils.list_to_tensor(batch["done"])
+        t4 = time.time()
         q1_loss, q2_loss = self.sac_q_loss(
             q1_stream, q2_stream, target_values, dones, rewards, masks
         )
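For context, sac_q_loss regresses each Q estimate toward the one-step backup r + gamma * (1 - done) * V_target(s'). A simplified sketch of that computation on flat tensors (the gamma and masking scheme here are illustrative assumptions; the actual ml-agents version operates per reward-signal stream):

# Simplified sketch of a SAC-style Q loss on flat tensors.
import torch

def sac_q_loss_sketch(
    q_stream: torch.Tensor,       # Q(s, a) for the batch actions
    target_values: torch.Tensor,  # V_target(s') from the target network
    dones: torch.Tensor,          # termination flags in {0, 1}
    rewards: torch.Tensor,
    masks: torch.Tensor,          # True where the timestep is not padding
    gamma: float = 0.99,
) -> torch.Tensor:
    with torch.no_grad():
        q_backup = rewards + (1.0 - dones) * gamma * target_values
    # Mean squared Bellman error over unmasked timesteps only.
    td_error = (q_backup - q_stream) ** 2
    return (td_error * masks.float()).sum() / masks.float().sum()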

         entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)
         total_value_loss = q1_loss + q2_loss + value_loss
+        t5 = time.time()
+        t6 = time.time()
+        t7 = time.time()
+        t8 = time.time()
         # Update target network
         self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
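soft_update blends the target network toward the current critic by the factor tau (Polyak averaging). A minimal sketch of that update rule with generic nn.Module arguments, not the ml-agents helper itself:

# Sketch of Polyak averaging for the target-network update.
import torch
from torch import nn

def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
    # target <- tau * source + (1 - tau) * target, parameter by parameter,
    # without tracking gradients.
    with torch.no_grad():
        for src, tgt in zip(source.parameters(), target.parameters()):
            tgt.mul_(1.0 - tau)
            tgt.add_(tau * src)

critic = nn.Linear(8, 1)
target = nn.Linear(8, 1)
soft_update(critic, target, tau=0.005)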
         update_stats = {

             "Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
             "Policy/Learning Rate": decay_lr,
         }
+        t9 = time.time()
+        print(
+            t9 - t8,
+            t8 - t7,
+            t7 - t6,
+            t6 - t5,
+            t5 - t4,
+            t4 - t3,
+            t3 - t2,
+            t2 - t1,
+            t1 - t0,
+        )
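The t0..t9 checkpoints above print raw interval durations from last to first, which takes some care to read back. A small sketch of the same profiling idea with labeled sections (mark() and the section names are made up for illustration):

# Sketch: labeled timing checkpoints instead of bare t0..t9 deltas.
import time

checkpoints = [("start", time.time())]

def mark(label: str) -> None:
    checkpoints.append((label, time.time()))

# ... run the code being profiled, calling mark() after each section:
mark("sample_actions")
mark("q_forward")
mark("losses")

for (prev_label, prev_t), (label, t) in zip(checkpoints, checkpoints[1:]):
    print(f"{prev_label} -> {label}: {t - prev_t:.4f}s")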
         return update_stats

     def update_reward_signals(
