    def __init__(
        self,
        sensor_specs: List[SensorSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        encoded_act_size: int = 0,
        num_obs_heads: int = 1,
    ):
        super().__init__()
        self.normalize = network_settings.normalize
        self.use_lstm = network_settings.memory is not None
        self.h_size = network_settings.hidden_units
        self.m_size = (
            network_settings.memory.memory_size
            if network_settings.memory is not None
            else 0
        )
        self.processors, _input_size = ModelUtils.create_input_processors(
            sensor_specs,
            self.h_size,
            network_settings.vis_encode_type,
            normalize=self.normalize,
        )
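        # create_input_processors returns one encoder module per sensor spec
        # along with each encoder's output size; those sizes are summed below
        # to compute the per-entity embedding widths.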
        self.action_spec = action_spec
        obs_only_ent_size = sum(_input_size)
        q_ent_size = (
            sum(_input_size)
            + sum(self.action_spec.discrete_branches)
            + self.action_spec.continuous_size
        )
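        # Two entity types feed the attention module: value entities carry only
        # the concatenated observation encodings, while Q entities additionally
        # carry each agent's flattened actions (one-hot per discrete branch
        # plus the raw continuous values), hence the two sizes above.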
        self.entity_encoder = EntityEmbeddings(
            0, [obs_only_ent_size, q_ent_size], self.h_size, concat_self=False
        )
        self.self_attn = ResidualSelfAttention(self.h_size)
        encoder_input_size = self.h_size
        total_enc_size = encoder_input_size + encoded_act_size
        self.linear_encoder = LinearEncoder(
            total_enc_size, network_settings.num_layers, self.h_size
        )
        if self.use_lstm:
            self.lstm = LSTM(self.h_size, self.m_size)

    def copy_normalization(self, other_network) -> None:
        # Keep observation normalization statistics in sync with another
        # network that shares the same processor layout.
        if self.normalize:
            for n1, n2 in zip(self.processors, other_network.processors):
                if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
                    n1.copy_normalization(n2)
    def _get_masks_from_nans(
        self, obs_tensors: List[List[torch.Tensor]]
    ) -> torch.Tensor:
        """
        Get attention masks by grabbing an arbitrary obs across all the agents.
        Since these are raw obs, the padded values are still NaN.
        """
        only_first_obs = [_all_obs[0] for _all_obs in obs_tensors]
        obs_for_mask = torch.stack(only_first_obs, dim=1)
        # Get the mask from NaNs
        attn_mask = torch.any(obs_for_mask.isnan(), dim=2).type(torch.FloatTensor)
        return attn_mask
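    # Example (hypothetical shapes): with 3 agents whose first obs is
    # (batch, obs_dim), obs_for_mask is (batch, 3, obs_dim) and the returned
    # mask is (batch, 3), with 1.0 marking agents whose obs were NaN-padded.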
    def forward(
        self,
        value_inputs: List[List[torch.Tensor]],
        q_inputs: List[List[torch.Tensor]],
        q_actions: List[AgentAction],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ):
        # Get attention masks by grabbing an arbitrary obs across all the
        # agents. Since these are raw obs, the padded values are still NaN.
        # The masks must be built before the NaN entries are zeroed out below.
        value_masks = self._get_masks_from_nans(value_inputs)
        q_masks = self._get_masks_from_nans(q_inputs)
        # Encode each agent's observations together with its flattened actions;
        # these form the Q-stream entities.
        concat_enc_q_obs = []
        for q_inp, q_action in zip(q_inputs, q_actions):
            encodes = []
            for idx, processor in enumerate(self.processors):
                obs_input = q_inp[idx]
                obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
                processed_obs = processor(obs_input)
                encodes.append(processed_obs)
            cat_encodes = [
                torch.cat(encodes, dim=-1),
                q_action.to_flat(self.action_spec.discrete_branches),
            ]
            concat_enc_q_obs.append(torch.cat(cat_encodes, dim=-1))
        q_input_concat = torch.stack(concat_enc_q_obs, dim=1)
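        # q_input_concat has shape (batch, num_q_agents, q_ent_size): one
        # entity per agent whose actions are visible to the Q stream.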
        # Encode the observation-only entities for the value stream.
        concat_encoded_obs = []
        for v_inp in value_inputs:
            encodes = []
            for idx, processor in enumerate(self.processors):
                obs_input = v_inp[idx]
                obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
                processed_obs = processor(obs_input)
                encodes.append(processed_obs)
            concat_encoded_obs.append(torch.cat(encodes, dim=-1))
        value_input_concat = torch.stack(concat_encoded_obs, dim=1)
        # Embed both entity streams and attend over them jointly, with a
        # separate padding mask per stream.
        encoded_entity = self.entity_encoder(
            None, [value_input_concat, q_input_concat]
        )
        encoded_state = self.self_attn(encoded_entity, [value_masks, q_masks])
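        # Masked residual self-attention pools a variable number of agent
        # entities into a fixed-size state encoding, so the same critic can
        # handle teams of different sizes.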
        # Optionally append pre-encoded actions (of size encoded_act_size)
        # before the body encoder.
        if actions is not None:
            inputs = torch.cat([encoded_state, actions], dim=-1)
        else:
            inputs = encoded_state
        encoding = self.linear_encoder(inputs)
        if self.use_lstm:
|