class ObservationEncoder(nn.Module): |
    ATTENTION_EMBEDDING_SIZE = 128  # The embedding size of attention is fixed

    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        h_size: int,
        vis_encode_type: EncoderType,
        normalize: bool = False,
    ):
        """
        Returns an ObservationEncoder that can process and encode a set of
        observations. Will use an RSA if needed for variable length observations.
        """
        super().__init__()
        self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
            observation_specs,
            h_size,
            vis_encode_type,
            self.ATTENTION_EMBEDDING_SIZE,
            normalize=normalize,
        )
        # Optional residual self-attention (RSA) block, created only when some
        # observations have a variable length; it embeds into the fixed attention size.
        self.rsa, self.x_self_encoder = ModelUtils.create_residual_self_attention(
            self.processors, self.embedding_sizes, self.ATTENTION_EMBEDDING_SIZE
        )
        if self.rsa is not None:
            total_enc_size = sum(self.embedding_sizes) + self.ATTENTION_EMBEDDING_SIZE
        else:
            total_enc_size = sum(self.embedding_sizes)
        self.normalize = normalize
        self._total_enc_size = total_enc_size
|
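# A minimal sketch (not part of the original module) of how the encoder's output
# width follows from the code above: the per-observation embeddings are concatenated
# and, when an RSA handles variable-length observations, the fixed attention
# embedding size is appended. The embedding sizes below are hypothetical.
def _total_enc_size_sketch() -> int:
    example_embedding_sizes = [64, 32]  # hypothetical per-observation embedding sizes
    return sum(example_embedding_sizes) + ObservationEncoder.ATTENTION_EMBEDDING_SIZE  # 224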
|
class MultiAgentNetworkBody(torch.nn.Module): |
    ATTENTION_EMBEDDING_SIZE = 128

    """
    A network body that uses a self attention layer to handle state
    and action input from a potentially variable number of agents that
    share the same observation and action space.
    """

    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
    ):
        super().__init__()
        self.action_spec = action_spec
        self.h_size = network_settings.hidden_units
        # ... (observation encoder construction and remaining setup omitted from
        # this excerpt)

        # Entity sizes for the attention modules: the encoded observations alone,
        # and the encoded observations concatenated with the flattened actions.
        obs_only_ent_size = self.observation_encoder.total_enc_size
        q_ent_size = (
            obs_only_ent_size
            + sum(self.action_spec.discrete_branches)
            + self.action_spec.continuous_size
        )

        # Entity embeddings project each agent's entities into the fixed
        # attention embedding space rather than the body's hidden size.
        self.obs_encoder = EntityEmbedding(
            obs_only_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
        )
        self.obs_action_encoder = EntityEmbedding(
            q_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
        )
        self.self_attn = ResidualSelfAttention(self.ATTENTION_EMBEDDING_SIZE)
        # Encode the attended result into the body's hidden size.
        self.linear_encoder = LinearEncoder(
            self.ATTENTION_EMBEDDING_SIZE,
            network_settings.num_layers,
            self.h_size,
            kernel_gain=(0.125 / self.h_size) ** 0.5,
        )
|
|
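    # Illustrative note (hypothetical numbers, not from the original source): with an
    # encoded observation size of 256, two discrete action branches of sizes 3 and 2,
    # and 4 continuous actions, obs_action_encoder above would see entities of size
    # 256 + (3 + 2) + 4 = 265, while obs_encoder would see entities of size 256.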
|
    def _copy_and_remove_nans_from_obs(
        self, all_obs: List[List[torch.Tensor]], attention_mask: torch.Tensor
    ) -> List[List[torch.Tensor]]:
        """
        Helper function to remove NaNs from observations using an attention mask.
        """
        obs_with_no_nans = []
        for i_agent, single_agent_obs in enumerate(all_obs):
            no_nan_obs = []
            for obs in single_agent_obs:
                new_obs = obs.clone()
                new_obs[attention_mask.bool()[:, i_agent], ::] = 0.0  # Remove NaNs fast
                no_nan_obs.append(new_obs)
            obs_with_no_nans.append(no_nan_obs)
        return obs_with_no_nans
|
|
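# A minimal sketch (not part of the original module) of the masking trick used in
# _copy_and_remove_nans_from_obs above: batch entries whose attention mask is 1
# (the agent is absent, so its observation is NaN padding) are overwritten with
# zeros so the NaNs cannot propagate through the network.
def _nan_masking_sketch() -> torch.Tensor:
    obs = torch.tensor([[1.0, 2.0], [float("nan"), float("nan")], [3.0, 4.0]])
    mask = torch.tensor([0.0, 1.0, 0.0])  # 1.0 marks an absent (padded) agent
    cleaned = obs.clone()
    cleaned[mask.bool(), :] = 0.0  # the NaN row becomes zeros
    return cleaned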
|