x_self = torch.cat(self_encodes, dim=-1)
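# x_self is the agent's own encoding: one vector per sample, built by
# concatenating the per-observation self encodings along the feature dim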
# Get attention masks by grabbing an arbitrary obs across all the agents
# Since these are raw obs, the padded values are still NaN
# (assumption: every raw obs has shape (batch, n_agents, obs_size), so any
# one of them can supply the mask; the first is used here)
obs_for_mask = inputs[0]
# Get the mask from NaNs: an agent is masked out if any of its values is NaN
attn_mask = torch.any(obs_for_mask.isnan(), dim=2).type(torch.FloatTensor)
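# For illustration (values assumed, not from the original): one sample with
# two agents, the second being NaN padding, gives
#   obs_for_mask: tensor([[[0.5, 1.0], [nan, nan]]])
#   attn_mask:    tensor([[0., 1.]])   # 1.0 marks the padded agent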
# Get the self encoding separately, but keep it in the entities
concat_encoded_obs = [x_self]
# Encode each observation with its matching processor
# (assumption: self.processors holds one encoder per observation)
encodes = []
for idx, processor in enumerate(self.processors):
    obs_input = inputs[idx]
    # Remove NaNs; the attention mask above was computed from them first,
    # so zeroing here is safe and makes the obs encodable
    obs_input[obs_input.isnan()] = 0.0
    processed_obs = processor(obs_input)
    encodes.append(processed_obs)
if len(encodes) == 0:
    raise Exception("No valid inputs to network.")
concat_encoded_obs.append(torch.cat(encodes, dim=-1))
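# encodes holds one (batch, size_i) tensor per observation; the cat yields a
# single (batch, sum_i size_i) encoding that joins x_self as one more entity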
# Entities are assumed to share one size so they can be stacked on dim 1
concat_entities = torch.stack(concat_encoded_obs, dim=1)
encoded_entity = self.entity_encoder(x_self, [concat_entities])
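# entity_encoder is assumed to embed each entity conditioned on x_self,
# producing the inputs for the attention layer below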
encoded_state = self.self_attn(encoded_entity, [attn_mask])
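# Self-attention pools the entities into a fixed-size state embedding; the
# mask is assumed to keep padded entities out of the attention weights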
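# A minimal standalone sketch of the NaN-masking technique used above; the
# tensors and sizes are assumptions for illustration only.
if __name__ == "__main__":
    import torch

    # One sample, two agents with three features each; agent 2 is NaN padding
    raw = torch.tensor([[[0.1, 0.2, 0.3], [float("nan")] * 3]])
    mask = torch.any(raw.isnan(), dim=2).type(torch.FloatTensor)
    raw[raw.isnan()] = 0.0  # now safe to encode; the mask remembers the padding
    print(mask)  # tensor([[0., 1.]])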