|
|
|
|
|
|
encodes.append(processed_obs) |
|
|
|
else: |
|
|
|
var_len_inputs.append(inputs[idx]) |
|
|
|
if len(encodes) == 0: |
|
|
|
encoded_self = torch.zeros(0, 0) |
|
|
|
else: |
|
|
|
if len(encodes) != 0: |
|
|
|
input_exist = True |
|
|
|
else: |
|
|
|
input_exist = False |
|
|
|
if len(var_len_inputs) > 0: |
|
|
|
# Some inputs need to be processed with a variable length encoder |
|
|
|
masks = get_zero_entities_mask(var_len_inputs) |
|
|
|
|
|
|
embeddings.append(var_len_processor(encoded_self, var_len_input)) |
|
|
|
qkv = torch.cat(embeddings, dim=1) |
|
|
|
attention_embedding = self.rsa(qkv, masks) |
|
|
|
if encoded_self.shape[1] == 0: |
|
|
|
if not input_exist: |
|
|
|
input_exist = True |
|
|
|
if encoded_self.shape[1] == 0: |
|
|
|
if not input_exist: |
|
|
|
raise Exception("No valid inputs to network.") |
|
|
|
|
|
|
|
# Constants don't work in Barracuda |
|
|
|