|
|
|
|
|
|
self_attn_masks.append(self._get_masks_from_nans(obs)) |
|
|
|
|
|
|
|
concat_encoded_obs = [] |
|
|
|
for inputs in self_obs: |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
concat_encoded_obs.append(torch.cat(encodes, dim=-1)) |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = self_obs[idx] |
|
|
|
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
concat_encoded_obs.append(torch.cat(encodes, dim=-1)) |
|
|
|
self_attn_masks.append(self._get_masks_from_nans(self_obs)) |
|
|
|
self_attn_masks.append(self._get_masks_from_nans([self_obs])) |
|
|
|
encoding, memories = self.forward( |
|
|
|
f_inp, |
|
|
|
g_inp, |
|
|
|