|
|
|
|
|
|
tensor_list: List[torch.Tensor], action_spec: ActionSpec |
|
|
|
) -> "AgentAction": |
|
|
|
continuous: torch.Tensor = None |
|
|
|
discrete: List[torch.Tensor] = None |
|
|
|
discrete: List[torch.Tensor] = None # type: ignore |
|
|
|
_offset = 0 |
|
|
|
if action_spec.continuous_size > 0: |
|
|
|
continuous = tensor_list[0] |
|
|
|
|
|
|
@staticmethod |
|
|
|
def from_dict(buff: Dict[str, np.ndarray]) -> "AgentAction": |
|
|
|
continuous: torch.Tensor = None |
|
|
|
discrete: List[torch.Tensor] = None |
|
|
|
discrete: List[torch.Tensor] = None # type: ignore |
|
|
|
discrete_tensor = ModelUtils.list_to_tensor(buff["discrete_action"]) |
|
|
|
discrete_tensor = ModelUtils.list_to_tensor( |
|
|
|
buff["discrete_action"], dtype=torch.long |
|
|
|
) |
|
|
|
discrete = [ |
|
|
|
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1]) |
|
|
|
] |
|
|
|
|
|
|
@property |
|
|
|
def discrete_tensor(self): |
|
|
|
return torch.stack(self.discrete_list, dim=-1) |
|
|
|
return torch.cat([_disc.unsqueeze(-1) for _disc in self.discrete_list], dim=1) |
|
|
|
|
|
|
|
@property |
|
|
|
def all_discrete_tensor(self): |
|
|
|
|
|
|
all_log_prob_list: List[torch.Tensor] = None, |
|
|
|
) -> "ActionLogProbs": |
|
|
|
continuous: torch.Tensor = None |
|
|
|
discrete: List[torch.Tensor] = None |
|
|
|
discrete: List[torch.Tensor] = None # type: ignore |
|
|
|
_offset = 0 |
|
|
|
if action_spec.continuous_size > 0: |
|
|
|
continuous = log_prob_list[0] |
|
|
|
|
|
|
@staticmethod |
|
|
|
def from_dict(buff: Dict[str, np.ndarray]) -> "ActionLogProbs": |
|
|
|
continuous: torch.Tensor = None |
|
|
|
discrete: List[torch.Tensor] = None |
|
|
|
discrete: List[torch.Tensor] = None # type: ignore |
|
|
|
if "continuous_log_probs" in buff: |
|
|
|
continuous = ModelUtils.list_to_tensor(buff["continuous_log_probs"]) |
|
|
|
if "discrete_log_probs" in buff: |
|
|
|