|
|
|
|
|
|
|
|
|
|
def baseline( |
|
|
|
self, |
|
|
|
self_obs: List[List[torch.Tensor]], |
|
|
|
self_obs: List[torch.Tensor], |
|
|
|
obs: List[List[torch.Tensor]], |
|
|
|
actions: List[AgentAction], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
|
|
|
|
|
|
|
f_inp = None |
|
|
|
concat_f_inp = [] |
|
|
|
concat_g_inp = [] |
|
|
|
for inputs, action in zip(obs, actions): |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
|
|
|
encodes.append(processed_obs) |
|
|
|
concat_g_inp.append(torch.cat(encodes, dim=-1)) |
|
|
|
cat_encodes = [ |
|
|
|
torch.cat(encodes, dim=-1), |
|
|
|
action.to_flat(self.action_spec.discrete_branches), |
|
|
|
|
|
|
f_inp = torch.stack(concat_f_inp, dim=1) |
|
|
|
self_attn_masks.append(self._get_masks_from_nans(obs)) |
|
|
|
|
|
|
|
concat_encoded_obs = [] |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = self_obs[idx] |
|
|
|
|
|
|
concat_encoded_obs.append(torch.cat(encodes, dim=-1)) |
|
|
|
g_inp = torch.stack(concat_encoded_obs, dim=1) |
|
|
|
concat_g_inp.append(torch.cat(encodes, dim=-1)) |
|
|
|
g_inp = torch.stack(concat_g_inp, dim=1) |
|
|
|
self_attn_masks.append(self._get_masks_from_nans([self_obs])) |
|
|
|
self_attn_masks.append(self._get_masks_from_nans([self_obs] + obs)) |
|
|
|
encoding, memories = self.forward( |
|
|
|
f_inp, |
|
|
|
g_inp, |
|
|
|