         sequence_length: int = 1,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         encodes = []
-        var_len_inputs = []  # The list of variable length inputs
-        var_len_processors = [
-            p for p in self.processors if isinstance(p, EntityEmbedding)
-        ]
+        var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

         for idx, processor in enumerate(self.processors):
             if not isinstance(processor, EntityEmbedding):
                 processed_obs = processor(inputs[idx])
                 encodes.append(processed_obs)
             else:
-                var_len_inputs.append(inputs[idx])
+                var_len_processor_inputs.append((processor, inputs[idx]))
-        if len(var_len_inputs) > 0:
+        if len(var_len_processor_inputs) > 0:
-            masks = get_zero_entities_mask(var_len_inputs)
+            masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
-            for var_len_input, processor in zip(var_len_inputs, var_len_processors):
+            for processor, var_len_input in var_len_processor_inputs:
                 embeddings.append(processor(processed_self, var_len_input))
             qkv = torch.cat(embeddings, dim=1)
             attention_embedding = self.rsa(qkv, masks)
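
For context, below is a minimal, self-contained sketch of the pattern this change adopts: instead of collecting variable-length inputs and their EntityEmbedding processors in two parallel lists and zip()-ing them later, each processor is paired with its input tensor at the moment it is collected. Everything in the sketch apart from that pairing idea is a hypothetical stand-in, not the surrounding codebase's API: the Toy* classes, toy_zero_entities_mask, and the masked-sum pooling that substitutes for the self.rsa attention step are invented here for illustration only.

from typing import List, Tuple

import torch
from torch import nn


class ToyEntityEmbedding(nn.Module):
    # Stand-in for an entity embedding: projects each entity of a
    # variable-length observation to a fixed embedding size.
    def __init__(self, entity_size: int, embed_size: int):
        super().__init__()
        self.proj = nn.Linear(entity_size, embed_size)

    def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
        # entities: [batch, num_entities, entity_size]; x_self is unused here
        # but kept to mirror the processor(processed_self, var_len_input) call.
        return self.proj(entities)


class ToyVectorEncoder(nn.Module):
    # Stand-in for a fixed-size observation encoder.
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.net = nn.Linear(input_size, hidden_size)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


def toy_zero_entities_mask(entity_tensors: List[torch.Tensor]) -> torch.Tensor:
    # Mark padded entities (all-zero feature vectors) with 1.0, real entities with 0.0.
    return torch.cat([(t.abs().sum(dim=2) == 0).float() for t in entity_tensors], dim=1)


def encode(processors: List[nn.Module], inputs: List[torch.Tensor]) -> torch.Tensor:
    encodes: List[torch.Tensor] = []
    # One list of (processor, input) pairs replaces the two parallel lists.
    var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []
    for processor, obs in zip(processors, inputs):
        if isinstance(processor, ToyEntityEmbedding):
            var_len_processor_inputs.append((processor, obs))
        else:
            encodes.append(processor(obs))
    encoded_self = torch.cat(encodes, dim=1)
    if var_len_processor_inputs:
        masks = toy_zero_entities_mask([obs for _, obs in var_len_processor_inputs])
        embeddings = [proc(encoded_self, obs) for proc, obs in var_len_processor_inputs]
        qkv = torch.cat(embeddings, dim=1)
        # A masked sum stands in for the residual self-attention (self.rsa) step.
        pooled = (qkv * (1.0 - masks).unsqueeze(2)).sum(dim=1)
        return torch.cat([encoded_self, pooled], dim=1)
    return encoded_self


# Usage: one fixed-size vector observation and one variable-length entity observation.
processors: List[nn.Module] = [ToyVectorEncoder(8, 16), ToyEntityEmbedding(6, 16)]
obs = [torch.randn(4, 8), torch.randn(4, 5, 6)]
print(encode(processors, obs).shape)  # torch.Size([4, 32])

The design point the sketch tries to show: pairing each processor with its tensor in one list removes the risk of the two parallel lists drifting out of sync, and the mask construction now reads directly off the same pairs instead of a separate inputs list.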