|
|
|
|
|
|
return AgentAction(continuous, discrete) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def extract(buff: Dict[str, np.ndarray]) -> "AgentAction": |
|
|
|
def from_dict(buff: Dict[str, np.ndarray]) -> "AgentAction": |
|
|
|
continuous: torch.Tensor = None |
|
|
|
discrete: List[torch.Tensor] = None |
|
|
|
if "continuous_action" in buff: |
|
|
|
|
|
|
return ActionLogProbs(continuous, discrete, all_log_prob_list) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def extract(buff: Dict[str, np.ndarray]) -> "ActionLogProbs": |
|
|
|
def from_dict(buff: Dict[str, np.ndarray]) -> "ActionLogProbs": |
|
|
|
continuous: torch.Tensor = None |
|
|
|
discrete: List[torch.Tensor] = None |
|
|
|
if "continuous_log_probs" in buff: |
|
|
|