|
|
|
|
|
|
|
|
|
|
class ActionFlattener: |
|
|
|
def __init__(self, action_spec: ActionSpec): |
|
|
|
""" |
|
|
|
A torch module that creates the flattened form of an AgentAction object. |
|
|
|
The flattened form is the continuous action concatenated with the |
|
|
|
concatenated one hot encodings of the discrete actions. |
|
|
|
:param action_spec: An ActionSpec that describes the action space dimensions |
|
|
|
""" |
|
|
|
""" |
|
|
|
The flattened size is the continuous size plus the sum of the branch sizes |
|
|
|
since discrete actions are encoded as one hots. |
|
|
|
""" |
|
|
|
""" |
|
|
|
Returns a tensor corresponding the flattened action |
|
|
|
:param action: An AgentAction object |
|
|
|
""" |
|
|
|
action_list: List[torch.Tensor] = [] |
|
|
|
if self._specs.continuous_size > 0: |
|
|
|
action_list.append(action.continuous_tensor) |
|
|
|