|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiAgentNetworkBody(torch.nn.Module): |
|
|
|
ATTENTION_EMBEDDING_SIZE = 128 |
|
|
|
|
|
|
|
""" |
|
|
|
A network body that uses a self attention layer to handle state |
|
|
|
and action input from a potentially variable number of agents that |
|
|
|
|
|
|
+ self.action_spec.continuous_size |
|
|
|
) |
|
|
|
|
|
|
|
attention_embeding_size = self.h_size |
|
|
|
obs_only_ent_size, None, self.ATTENTION_EMBEDDING_SIZE |
|
|
|
obs_only_ent_size, None, attention_embeding_size |
|
|
|
q_ent_size, None, self.ATTENTION_EMBEDDING_SIZE |
|
|
|
q_ent_size, None, attention_embeding_size |
|
|
|
self.self_attn = ResidualSelfAttention(self.ATTENTION_EMBEDDING_SIZE) |
|
|
|
self.self_attn = ResidualSelfAttention(attention_embeding_size) |
|
|
|
self.ATTENTION_EMBEDDING_SIZE, |
|
|
|
attention_embeding_size, |
|
|
|
network_settings.num_layers, |
|
|
|
self.h_size, |
|
|
|
kernel_gain=(0.125 / self.h_size) ** 0.5, |
|
|
|