|
|
|
|
|
|
from mlagents.trainers.settings import NetworkSettings |
|
|
|
from mlagents.trainers.torch.utils import ModelUtils |
|
|
|
from mlagents.trainers.torch.decoders import ValueHeads |
|
|
|
from mlagents.trainers.torch.layers import LSTM, LinearEncoder |
|
|
|
from mlagents.trainers.torch.layers import LSTM, LinearEncoder, LayerNorm |
|
|
|
from mlagents.trainers.torch.encoders import VectorInput |
|
|
|
from mlagents.trainers.buffer import AgentBuffer |
|
|
|
from mlagents.trainers.trajectory import ObsUtil |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
if len(self.var_processors) > 0: |
|
|
|
self.embedding_norm = LayerNorm() |
|
|
|
self.rsa = ResidualSelfAttention(self.h_size, entity_num_max) |
|
|
|
total_enc_size = sum(self.embedding_sizes) + self.h_size |
|
|
|
n_layers = max(1, network_settings.num_layers - 2) |
|
|
|
|
|
|
): |
|
|
|
embeddings.append(var_len_processor(encoded_self, var_len_input)) |
|
|
|
qkv = torch.cat(embeddings, dim=1) |
|
|
|
qkv = self.embedding_norm(qkv) |
|
|
|
attention_embedding = self.rsa(qkv, masks) |
|
|
|
encoded_self = torch.cat([encoded_self, attention_embedding], dim=1) |
|
|
|
|
|
|
|