|
|
|
|
|
|
def _group_from_buffer( |
|
|
|
buff: AgentBuffer, cont_action_key: BufferKey, disc_action_key: BufferKey |
|
|
|
) -> List["AgentAction"]: |
|
|
|
""" |
|
|
|
Extracts continuous and discrete groupmate actions, as specified by BufferKey, and |
|
|
|
returns a List of AgentActions that correspond to the groupmate's actions. List will |
|
|
|
be of length equal to the maximum number of groupmates in the buffer. Any spots where |
|
|
|
there are fewer agents than the maximum, the actions will be padded with 0's. |
|
|
|
""" |
|
|
|
continuous_tensors: List[torch.Tensor] = [] |
|
|
|
discrete_tensors: List[torch.Tensor] = [] |
|
|
|
if cont_action_key in buff: |
|
|
|
|
|
|
@staticmethod |
|
|
|
def group_from_buffer(buff: AgentBuffer) -> List["AgentAction"]: |
|
|
|
""" |
|
|
|
A static method that accesses continuous and discrete action fields in an AgentBuffer |
|
|
|
and constructs the corresponding AgentAction from the retrieved np arrays. |
|
|
|
A static method that accesses group continuous and discrete action fields in an AgentBuffer |
|
|
|
and constructs a padded List of AgentActions that represent the group agent actions. |
|
|
|
The List is of length equal to max number of groupmate agents in the buffer, and the AgentBuffer is |
|
|
|
of the same length as the buffer. Empty spots (e.g. when agents die) are padded with 0. |
|
|
|
:param buff: AgentBuffer of a batch or trajectory |
|
|
|
:return: List of groupmate's AgentActions |
|
|
|
""" |
|
|
|
return AgentAction._group_from_buffer( |
|
|
|
buff, BufferKey.GROUP_CONTINUOUS_ACTION, BufferKey.GROUP_DISCRETE_ACTION |
|
|
|
|
|
|
def group_from_buffer_next(buff: AgentBuffer) -> List["AgentAction"]: |
|
|
|
""" |
|
|
|
A static method that accesses next continuous and discrete action fields in an AgentBuffer |
|
|
|
and constructs the corresponding AgentAction from the retrieved np arrays. |
|
|
|
A static method that accesses next group continuous and discrete action fields in an AgentBuffer |
|
|
|
and constructs a padded List of AgentActions that represent the next group agent actions. |
|
|
|
The List is of length equal to max number of groupmate agents in the buffer, and the AgentBuffer is |
|
|
|
of the same length as the buffer. Empty spots (e.g. when agents die) are padded with 0. |
|
|
|
:param buff: AgentBuffer of a batch or trajectory |
|
|
|
:return: List of groupmate's AgentActions |
|
|
|
""" |
|
|
|
return AgentAction._group_from_buffer( |
|
|
|
buff, BufferKey.GROUP_NEXT_CONT_ACTION, BufferKey.GROUP_NEXT_DISC_ACTION |
|
|
|
|
|
|
""" |
|
|
|
Flatten this AgentAction into a single torch Tensor of dimension (batch, num_continuous + num_one_hot_discrete). |
|
|
|
Discrete actions are converted into one-hot and concatenated with continuous actions. |
|
|
|
:param discrete_branches: List of sizes for discrete actions. |
|
|
|
:return: Tensor of flattened actions. |
|
|
|
""" |
|
|
|
discrete_oh = ModelUtils.actions_to_onehot( |
|
|
|
self.discrete_tensor, discrete_branches |
|
|
|
) |