
Try futures in Optimizer

/develop/jit/experiments
Ervin Teng, 4 years ago
Current commit
228ea059
1 changed file with 21 additions and 21 deletions
ml-agents/mlagents/trainers/sac/optimizer_torch.py (42 changed lines: +21 −21)
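
The hunks below comment out the t0…t9 wall-clock brackets and the final print that were used to profile SACOptimizer.update(). For reference only, the same measurement is often written as a small context manager; wall_clock below is a hypothetical helper sketched for illustration, not part of ml-agents:

    import time
    from contextlib import contextmanager

    @contextmanager
    def wall_clock(label: str, timings: dict):
        # Hypothetical helper: record the wall-clock duration of the
        # enclosed block under `label` in the `timings` dict.
        start = time.time()
        try:
            yield
        finally:
            timings[label] = time.time() - start

Used as `with wall_clock("q_loss", timings): ...` around each region, a single `print(timings)` then replaces the nine hand-written subtractions in the deleted block.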


         indexed by name. If none, don't update the reward signals.
         :return: Output from update process.
         """
-        t0 = time.time()
+        # t0 = time.time()
         rewards = {}
         for name in self.reward_signals:
             rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

         self.target_network.network_body.copy_normalization(
             self.policy.actor_critic.network_body
         )
-        t1 = time.time()
+        # t1 = time.time()
         (
             sampled_actions,
             log_probs,

             seq_len=self.policy.sequence_length,
             all_log_probs=not self.policy.use_continuous_act,
         )
-        t2 = time.time()
+        # t2 = time.time()
         if self.policy.use_continuous_act:
             squeezed_actions = actions.squeeze(-1)
             # q1p_out, q2p_out = self.value_network(

         )
         q1_stream = self._condense_q_streams(q1_out, actions)
         q2_stream = self._condense_q_streams(q2_out, actions)
-        t3 = time.time()
+        # t3 = time.time()
         with torch.no_grad():
             target_values, _ = self.target_network(
                 next_vec_obs,

         masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
         use_discrete = not self.policy.use_continuous_act
         dones = ModelUtils.list_to_tensor(batch["done"])
-        t4 = time.time()
+        # t4 = time.time()
         q1_loss, q2_loss = self.sac_q_loss(
             q1_stream, q2_stream, target_values, dones, rewards, masks
         )

         entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)
         total_value_loss = q1_loss + q2_loss + value_loss
-        t5 = time.time()
+        # t5 = time.time()
-        t6 = time.time()
+        # t6 = time.time()
-        t7 = time.time()
+        # t7 = time.time()
-        t8 = time.time()
+        # t8 = time.time()
         # Update target network
         self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
         update_stats = {

             "Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
             "Policy/Learning Rate": decay_lr,
         }
-        t9 = time.time()
-        print(
-            t9 - t8,
-            t8 - t7,
-            t7 - t6,
-            t6 - t5,
-            t5 - t4,
-            t4 - t3,
-            t3 - t2,
-            t2 - t1,
-            t1 - t0,
-        )
+        # t9 = time.time()
+        # print(
+        #     t9 - t8,
+        #     t8 - t7,
+        #     t7 - t6,
+        #     t6 - t5,
+        #     t5 - t4,
+        #     t4 - t3,
+        #     t3 - t2,
+        #     t2 - t1,
+        #     t1 - t0,
+        # )
         return update_stats

     def update_reward_signals(
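
The commit title, together with the /develop/jit/experiments branch name, suggests the experiment swaps some of the sequential calls in update() for TorchScript futures. torch.jit.fork schedules a callable as an asynchronous task and returns a Future; torch.jit.wait blocks on its result (tasks only run in parallel under TorchScript; in eager mode fork executes the call immediately). A minimal sketch of that pattern on two independent Q evaluations; the function and argument names here are illustrative, not the actual ml-agents code:

    import torch
    from torch import nn

    def forked_q_outputs(q1_net: nn.Module, q2_net: nn.Module, obs: torch.Tensor):
        # Hypothetical sketch, not the ml-agents implementation.
        # Schedule both Q evaluations; each fork returns a torch.jit.Future
        # immediately instead of blocking on the forward pass.
        fut1 = torch.jit.fork(q1_net, obs)
        fut2 = torch.jit.fork(q2_net, obs)
        # Block until each task completes and return both outputs.
        return torch.jit.wait(fut1), torch.jit.wait(fut2)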

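The self.soft_update(...) call near the end of the hunk is SAC's Polyak averaging of critic weights into the target network, controlled by tau. A minimal sketch of that update rule, assuming plain nn.Module parameter lists (the real ml-agents helper may differ in signature and details):

    import torch
    from torch import nn

    @torch.no_grad()
    def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
        # Hypothetical sketch of Polyak averaging. Move each target parameter
        # a fraction tau toward its source counterpart:
        #   theta_target <- tau * theta_source + (1 - tau) * theta_target
        for src, tgt in zip(source.parameters(), target.parameters()):
            tgt.data.mul_(1.0 - tau).add_(tau * src.data)

With a small tau the target network trails the critic slowly, which stabilizes the bootstrapped target_values computed under torch.no_grad() earlier in update().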