|
|
|
|
|
|
discrete_list: Optional[List[torch.Tensor]] |
|
|
|
|
|
|
|
@property |
|
|
|
def discrete_tensor(self): |
|
|
|
def discrete_tensor(self) -> torch.Tensor: |
|
|
|
return torch.stack(self.discrete_list, dim=-1) |
|
|
|
if self.discrete_list is not None and len(self.discrete_list) > 0: |
|
|
|
return torch.stack(self.discrete_list, dim=-1) |
|
|
|
else: |
|
|
|
return torch.empty(0) |
|
|
|
|
|
|
|
def to_action_tuple(self, clip: bool = False) -> ActionTuple: |
|
|
|
""" |
|
|
|
|
|
|
:param discrete_branches: List of sizes for discrete actions. |
|
|
|
:return: Tensor of flattened actions. |
|
|
|
""" |
|
|
|
discrete_oh = ModelUtils.actions_to_onehot( |
|
|
|
self.discrete_tensor, discrete_branches |
|
|
|
) |
|
|
|
discrete_oh = torch.cat(discrete_oh, dim=1) |
|
|
|
# if there are any discrete actions, create one-hot |
|
|
|
if self.discrete_list is not None and self.discrete_list: |
|
|
|
discrete_oh = ModelUtils.actions_to_onehot( |
|
|
|
self.discrete_tensor, discrete_branches |
|
|
|
) |
|
|
|
discrete_oh = torch.cat(discrete_oh, dim=1) |
|
|
|
else: |
|
|
|
discrete_oh = torch.empty(0) |
|
|
|
return torch.cat([self.continuous_tensor, discrete_oh], dim=-1) |