|
|
|
|
|
|
from typing import Dict, Any |
|
|
|
from typing import Dict |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
self.has_updated = False |
|
|
|
self.use_recurrent = self.policy.use_recurrent |
|
|
|
self.samples_per_update = settings.samples_per_update |
|
|
|
# self.out_dict = { |
|
|
|
# "loss": self.model.loss, |
|
|
|
# "update": self.model.update_batch, |
|
|
|
# "learning_rate": self.model.annealed_learning_rate, |
|
|
|
# } |
|
|
|
def update(self) -> Dict[str, Any]: |
|
|
|
def update(self) -> Dict[str, np.ndarray]: |
|
|
|
""" |
|
|
|
Updates model using buffer. |
|
|
|
:param max_batches: The maximum number of batches to use per update. |
|
|
|
|
|
|
return bc_loss |
|
|
|
|
|
|
|
def _update_batch( |
|
|
|
self, mini_batch_demo: Dict[str, Any], n_sequences: int |
|
|
|
) -> Dict[str, Any]: |
|
|
|
self, mini_batch_demo: Dict[str, np.ndarray], n_sequences: int |
|
|
|
) -> Dict[str, float]: |
|
|
|
""" |
|
|
|
Helper function for update_batch. |
|
|
|
""" |
|
|
|
|
|
|
else: |
|
|
|
vis_obs = [] |
|
|
|
|
|
|
|
selected_actions, all_log_probs, entropies, values, memories = self.policy.sample_actions( |
|
|
|
selected_actions, all_log_probs, _, _, _ = self.policy.sample_actions( |
|
|
|
vec_obs, |
|
|
|
vis_obs, |
|
|
|
masks=act_masks, |
|
|
|