|
|
|
|
|
|
continuous_tensor: torch.Tensor |
|
|
|
discrete_list: Optional[List[torch.Tensor]] |
|
|
|
|
|
|
|
def __getitem__(self, index): |
|
|
|
if isinstance(index, slice): |
|
|
|
_cont = None |
|
|
|
_disc_list = [] |
|
|
|
if self.continuous_tensor is not None: |
|
|
|
_cont = self.continuous_tensor.__getitem__(index) |
|
|
|
if self.discrete_list is not None and len(self.discrete_list) > 0: |
|
|
|
for _disc in self.discrete_list: |
|
|
|
_disc_list.append(_disc.__getitem__(index)) |
|
|
|
return AgentAction(_cont, _disc_list) |
|
|
|
else: |
|
|
|
return super().__getitem__(index) |
|
|
|
|
|
|
|
@property |
|
|
|
def discrete_tensor(self) -> torch.Tensor: |
|
|
|
""" |
|
|
|
|
|
|
return torch.stack(self.discrete_list, dim=-1) |
|
|
|
else: |
|
|
|
return torch.empty(0) |
|
|
|
|
|
|
|
def slice(self, start: int, end: int) -> "AgentAction": |
|
|
|
""" |
|
|
|
Returns an AgentAction with the continuous and discrete tensors slices |
|
|
|
from index start to index end. |
|
|
|
""" |
|
|
|
_cont = None |
|
|
|
_disc_list = [] |
|
|
|
if self.continuous_tensor is not None: |
|
|
|
_cont = self.continuous_tensor[start:end] |
|
|
|
if self.discrete_list is not None and len(self.discrete_list) > 0: |
|
|
|
for _disc in self.discrete_list: |
|
|
|
_disc_list.append(_disc[start:end]) |
|
|
|
return AgentAction(_cont, _disc_list) |
|
|
|
|
|
|
|
def to_action_tuple(self, clip: bool = False) -> ActionTuple: |
|
|
|
""" |
|
|
|