|
|
|
|
|
|
Since these are raw obs, the padded values are still NaN |
|
|
|
""" |
|
|
|
only_first_obs = [_all_obs[0] for _all_obs in obs_tensors] |
|
|
|
# flatten for correct dimensions with visual obs |
|
|
|
# Just get the first element in each obs regardless of its dimension. This will speed up |
|
|
|
# searching for NaNs. |
|
|
|
# Get the mask from NaNs |
|
|
|
attn_mask = only_first_obs_flat.isnan().type(torch.FloatTensor) |
|
|
|
return attn_mask |
|
|
|
|
|
|
|
|
|
|
self_attn_masks = [] |
|
|
|
self_attn_inputs = [] |
|
|
|
concat_f_inp = [] |
|
|
|
for inputs, action in zip(obs, actions): |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
cat_encodes = [ |
|
|
|
torch.cat(encodes, dim=-1), |
|
|
|
action.to_flat(self.action_spec.discrete_branches), |
|
|
|
] |
|
|
|
concat_f_inp.append(torch.cat(cat_encodes, dim=1)) |
|
|
|
|
|
|
|
if concat_f_inp: |
|
|
|
if obs: |
|
|
|
obs_attn_mask = self._get_masks_from_nans(obs) |
|
|
|
for i_agent, (inputs, action) in enumerate(zip(obs, actions)): |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
obs_input[ |
|
|
|
obs_attn_mask.type(torch.BoolTensor)[:, i_agent], :: |
|
|
|
] = 0.0 # Remove NaNs fast |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
cat_encodes = [ |
|
|
|
torch.cat(encodes, dim=-1), |
|
|
|
action.to_flat(self.action_spec.discrete_branches), |
|
|
|
] |
|
|
|
concat_f_inp.append(torch.cat(cat_encodes, dim=1)) |
|
|
|
self_attn_masks.append(self._get_masks_from_nans(obs)) |
|
|
|
self_attn_masks.append(obs_attn_mask) |
|
|
|
for inputs in obs_only: |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
concat_encoded_obs.append(torch.cat(encodes, dim=-1)) |
|
|
|
g_inp = torch.stack(concat_encoded_obs, dim=1) |
|
|
|
self_attn_masks.append(self._get_masks_from_nans(obs_only)) |
|
|
|
self_attn_inputs.append(self.obs_encoder(None, g_inp)) |
|
|
|
if obs_only: |
|
|
|
obs_only_attn_mask = self._get_masks_from_nans(obs_only) |
|
|
|
for i_agent, inputs in enumerate(obs_only): |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
obs_input[ |
|
|
|
obs_only_attn_mask.type(torch.BoolTensor)[:, i_agent], :: |
|
|
|
] = 0.0 # Remove NaNs fast |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
concat_encoded_obs.append(torch.cat(encodes, dim=-1)) |
|
|
|
g_inp = torch.stack(concat_encoded_obs, dim=1) |
|
|
|
self_attn_masks.append(obs_only_attn_mask) |
|
|
|
self_attn_inputs.append(self.obs_encoder(None, g_inp)) |
|
|
|
|
|
|
|
encoded_entity = torch.cat(self_attn_inputs, dim=1) |
|
|
|
encoded_state = self.self_attn(encoded_entity, self_attn_masks) |
|
|
|