import threading
from typing import Dict, List, Mapping, cast, Tuple, Optional

import numpy as np

from mlagents.torch_utils import torch, nn, default_device
from mlagents.trainers.buffer import AgentBuffer
# Assumed import path for ModelUtils; newer ml-agents releases move it to
# mlagents.trainers.torch_entities.utils.
from mlagents.trainers.torch.utils import ModelUtils

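        # policy_loss, total_value_loss, and entropy_loss are assumed to have
        # been computed earlier in this update step; the gradient work for all
        # three is overlapped by running each backward pass on its own thread.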
        # Decay the learning rate and apply it to all three optimizers.
        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)

        # Clear stale gradients before the backward passes.
        self.policy_optimizer.zero_grad()
        self.value_optimizer.zero_grad()
        self.entropy_optimizer.zero_grad()

        # Run each backward pass on its own worker thread.
        ploss_thread = self.spawn_backward_thread(policy_loss)
        vloss_thread = self.spawn_backward_thread(total_value_loss)
        entloss_thread = self.spawn_backward_thread(entropy_loss)

        # Wait for every backward pass to finish before stepping.
        ploss_thread.join()
        vloss_thread.join()
        entloss_thread.join()

        self.policy_optimizer.step()
        self.value_optimizer.step()
        self.entropy_optimizer.step()
        # Update target network

        update_stats = {
            "Losses/Policy Loss": policy_loss.item(),
            "Losses/Value Loss": total_value_loss.item(),
            "Losses/Entropy Loss": entropy_loss.item(),
            "Policy/Learning Rate": decay_lr,
        }

        return update_stats

    def spawn_backward_thread(self, loss: torch.Tensor) -> threading.Thread:
        """
        Start loss.backward() on a worker thread and return the running thread
        so the caller can join() it before stepping the optimizers.
        """
        thr = threading.Thread(target=loss.backward)
        thr.start()
        return thr
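    # Design note: PyTorch (1.6+) supports calling backward() concurrently
    # from multiple Python threads, so the losses above can accumulate their
    # gradients in parallel. Every spawned thread must be joined before
    # optimizer.step() so no step runs on partially written gradients.
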
    def update_reward_signals(
        self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int