
Run backwards() of losses in threads

/develop/torch-sac-threading
Ervin Teng, 4 years ago
Current commit 916eec4b
1 file changed, 15 insertions and 5 deletions

ml-agents/mlagents/trainers/sac/optimizer_torch.py: 20 changes (+15, −5)


 import numpy as np
+import threading
 from typing import Dict, List, Mapping, cast, Tuple, Optional
 from mlagents.torch_utils import torch, nn, default_device

         decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
         ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
         self.policy_optimizer.zero_grad()
-        policy_loss.backward()
-        self.policy_optimizer.step()
+        ploss_thread = self.spawn_backward_thread(policy_loss)

-        total_value_loss.backward()
-        self.value_optimizer.step()
+        vloss_thread = self.spawn_backward_thread(total_value_loss)

-        entropy_loss.backward()
+        entloss_thread = self.spawn_backward_thread(entropy_loss)
+        ploss_thread.join()
+        vloss_thread.join()
+        entloss_thread.join()
+        self.policy_optimizer.step()
+        self.value_optimizer.step()
         self.entropy_optimizer.step()
         # Update target network

         }
         return update_stats

+    def spawn_backward_thread(self, loss: torch.Tensor) -> threading.Thread:
+        # Run loss.backward() on its own thread; caller joins before stepping.
+        thr = threading.Thread(target=loss.backward)
+        thr.start()
+        return thr

     def update_reward_signals(
         self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int
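
The change replaces the three sequential backward() calls with one worker thread per loss: each thread runs loss.backward(), the main thread joins all three, and only then are the optimizers stepped. Below is a minimal, self-contained sketch of the same pattern, assuming a recent PyTorch whose autograd engine allows concurrent backward() calls from multiple Python threads; the toy networks, optimizer settings, and learning rates are illustrative stand-ins, not ml-agents code.

import threading

import torch
from torch import nn


def spawn_backward_thread(loss: torch.Tensor) -> threading.Thread:
    # Same helper as in the commit: kick off loss.backward() on its own
    # thread and hand the thread back so the caller can join() it later.
    thr = threading.Thread(target=loss.backward)
    thr.start()
    return thr


# Toy stand-ins for the SAC policy, value, and entropy-coefficient
# parameters (hypothetical; the real optimizer holds full networks).
policy_net = nn.Linear(4, 2)
value_net = nn.Linear(4, 1)
log_ent_coef = nn.Parameter(torch.zeros(1))

policy_opt = torch.optim.Adam(policy_net.parameters(), lr=3e-4)
value_opt = torch.optim.Adam(value_net.parameters(), lr=3e-4)
entropy_opt = torch.optim.Adam([log_ent_coef], lr=3e-4)

obs = torch.randn(8, 4)
policy_loss = policy_net(obs).pow(2).mean()
total_value_loss = value_net(obs).pow(2).mean()
entropy_loss = -log_ent_coef.mean()

for opt in (policy_opt, value_opt, entropy_opt):
    opt.zero_grad()

# The three backward passes run concurrently, one per thread.
threads = [
    spawn_backward_thread(policy_loss),
    spawn_backward_thread(total_value_loss),
    spawn_backward_thread(entropy_loss),
]
for thr in threads:
    thr.join()  # every .grad is fully populated once all threads finish

# Step only after all joins, matching the reordering in the commit.
policy_opt.step()
value_opt.step()
entropy_opt.step()

Deferring every step() until after all joins is the important ordering constraint: if any of the losses share parameters, as a SAC policy loss typically does with the critic it backpropagates through, stepping one optimizer while another thread is still accumulating into the same .grad tensors would race.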
