|
|
|
|
|
|
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch

# NOTE(review): `timed` is used as a decorator below but its import did not
# survive whatever mangled this file (the typing import was also duplicated —
# merged into the single superset line above). `ModelUtils` is likewise needed
# by `sample_actions`. The module paths below follow the upstream ml-agents
# layout — confirm against VCS history.
from mlagents_envs.timers import timed

from mlagents.trainers.torch.utils import ModelUtils
|
|
|
|
|
|
|
|
|
|
@timed
def sample_actions(
    self,
    vec_obs: List[torch.Tensor],
    vis_obs: List[torch.Tensor],
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    seq_len: int = 1,
    all_log_probs: bool = False,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
]:
    """
    Sample actions from the current policy for a batch of observations.

    NOTE(review): this block arrived merge-corrupted — the signature appeared
    twice (once untyped, once typed) and the body was gone except for a stray
    closing paren. The merged signature keeps the typed variant, which is a
    backward-compatible superset of the untyped one. The body was reconstructed
    to match the visible return annotation and the sibling
    ``evaluate_actions`` call pattern — confirm against VCS history.

    :param vec_obs: List of vector observation tensors.
    :param vis_obs: List of visual observation tensors.
    :param masks: Action masks, else None.  # presumably discrete-branch masks — TODO confirm
    :param memories: Input memories when using an RNN, else None.
    :param seq_len: Sequence length when using an RNN.
    :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
    :return: Tuple of (actions, log probs, entropies, value heads, memories),
        matching the annotated 5-tuple above.
    """
    # Same actor-critic query the sibling evaluate_actions performs.
    dists, value_heads, memories = self.actor_critic.get_dist_and_value(
        vec_obs, vis_obs, masks, memories, seq_len
    )
    action_list = self.actor_critic.sample_action(dists)
    log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
        action_list, dists
    )
    actions = torch.stack(action_list, dim=-1)
    # Squeeze the singleton axis: continuous actions stack to
    # (batch, act_size, 1), discrete to (batch, 1, num_branches)
    # — assumed from upstream; TODO confirm.
    if self.use_continuous_act:
        actions = actions[:, :, 0]
    else:
        actions = actions[:, 0, :]

    return (
        actions,
        all_logs if all_log_probs else log_probs,
        entropies,
        value_heads,
        memories,
    )
|
|
|
|
|
|
|
# NOTE(review): merge corruption — two signatures of this method are fused
# together: the untyped `def` line below is immediately followed by the
# parameter list of a typed variant where the body should begin, so this block
# cannot parse (lines such as `vec_obs: torch.Tensor,` are not valid
# statements, and the original indentation appears to have been lost as well).
# The method also continues past the end of this chunk — no return statement
# is visible. Restore from VCS history rather than hand-patching; the comments
# below only label the fragments, no code has been altered.
def evaluate_actions(
    self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1
):
# -- fragment: parameter list of the typed signature (not executable code) --
self,
vec_obs: torch.Tensor,
vis_obs: torch.Tensor,
actions: torch.Tensor,
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
seq_len: int = 1,
# -- fragment: return annotation + colon of the typed signature --
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
# -- fragment: start of the real method body — queries the actor-critic for
# action distributions and value estimates; the rest of the method (and its
# return of the annotated 3-tuple) lies beyond this chunk. --
dists, value_heads, _ = self.actor_critic.get_dist_and_value(
    vec_obs, vis_obs, masks, memories, seq_len
)
|
|
|