|
|
|
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
# Tensors that go into ResidualSelfAttention |
|
|
|
self_attn_inputs = [] |
|
|
|
self_attn_masks = [] |
|
|
|
|
|
|
|
# Get the self encoding separately, but keep it in the entities |
|
|
|
concat_enc_q_obs = [] |
|
|
|
for inputs, actions in zip(q_inputs, q_actions): |
|
|
|
|
|
|
torch.cat(encodes, dim=-1), |
|
|
|
actions.to_flat(self.action_spec.discrete_branches), |
|
|
|
] |
|
|
|
concat_enc_q_obs.append(torch.cat(cat_encodes, dim=-1)) |
|
|
|
q_input_concat = torch.stack(concat_enc_q_obs, dim=1) |
|
|
|
concat_enc_q_obs.append(torch.cat(cat_encodes, dim=1)) |
|
|
|
if concat_enc_q_obs: |
|
|
|
q_input_concat = torch.stack(concat_enc_q_obs, dim=1) |
|
|
|
self_attn_masks.append(self._get_masks_from_nans(q_inputs)) |
|
|
|
encoded_obs_action = self.obs_action_encoder(None, q_input_concat) |
|
|
|
self_attn_inputs.append(encoded_obs_action) |
|
|
|
|
|
|
|
# Get the self encoding separately, but keep it in the entities |
|
|
|
concat_encoded_obs = [] |
|
|
|
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
concat_encoded_obs.append(torch.cat(encodes, dim=-1)) |
|
|
|
|
|
|
|
value_input_concat = torch.stack(concat_encoded_obs, dim=1) |
|
|
|
|
|
|
|
# Get the mask from nans |
|
|
|
value_masks = self._get_masks_from_nans(value_inputs) |
|
|
|
q_masks = self._get_masks_from_nans(q_inputs) |
|
|
|
|
|
|
|
encoded_obs = self.obs_encoder(None, value_input_concat) |
|
|
|
encoded_obs_action = self.obs_action_encoder(None, q_input_concat) |
|
|
|
encoded_entity = torch.cat([encoded_obs, encoded_obs_action], dim=1) |
|
|
|
encoded_state = self.self_attn(encoded_entity, [value_masks, q_masks]) |
|
|
|
if concat_encoded_obs: |
|
|
|
value_input_concat = torch.stack(concat_encoded_obs, dim=1) |
|
|
|
# Get the mask from nans |
|
|
|
self_attn_masks.append(self._get_masks_from_nans(value_inputs)) |
|
|
|
encoded_obs = self.obs_encoder(None, value_input_concat) |
|
|
|
self_attn_inputs.append(encoded_obs) |
|
|
|
if len(concat_encoded_obs) == 0: |
|
|
|
raise Exception("No valid inputs to network.") |
|
|
|
encoded_entity = torch.cat(self_attn_inputs, dim=1) |
|
|
|
encoded_state = self.self_attn(encoded_entity, self_attn_masks) |
|
|
|
|
|
|
|
inputs = encoded_state |
|
|
|
encoding = self.linear_encoder(inputs) |
|
|
|
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
critic_obs: List[List[torch.Tensor]] = None, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
actor_mem, critic_mem = self._get_actor_critic_mem(memories) |
|
|
|
all_net_inputs = [inputs] |
|
|
|
if critic_obs is not None and critic_obs: |
|
|
|