
move embedding norm into RSA block

/develop/singular-embeddings
Andrew Cohen, 4 years ago
Current commit: 86d4c5c5
2 changed files with 4 additions and 3 deletions
  1. ml-agents/mlagents/trainers/torch/attention.py (3 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (4 changes)

ml-agents/mlagents/trainers/torch/attention.py (3 changes)


kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
+ self.embedding_norm = LayerNorm()

+ inp = self.embedding_norm(inp)
# Feed to self attention
query = self.fc_q(inp) # (b, n_q, emb)
key = self.fc_k(inp) # (b, n_k, emb)
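
The two added lines give the attention block its own pre-attention normalization: the LayerNorm module is constructed next to the q/k/v/out projections, and applied to the input entities at the top of forward. A minimal sketch of the resulting structure is below; LayerNorm and Initialization come from the hunk itself, while linear_layer and the attention body are assumptions for illustration, not the exact ml-agents implementation.

import torch
from mlagents.trainers.torch.layers import Initialization, LayerNorm, linear_layer


class ResidualSelfAttentionSketch(torch.nn.Module):
    # Sketch only: the pre-attention LayerNorm is owned and applied by the block itself.
    def __init__(self, embedding_size: int) -> None:
        super().__init__()
        self.fc_q = linear_layer(embedding_size, embedding_size)
        self.fc_k = linear_layer(embedding_size, embedding_size)
        self.fc_v = linear_layer(embedding_size, embedding_size)
        self.fc_out = linear_layer(
            embedding_size,
            embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )
        # Moved here from networks.py by this commit.
        self.embedding_norm = LayerNorm()

    def forward(self, inp: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # inp: (b, n, emb) entity embeddings; mask: (b, n) with 1 for padded entities.
        inp = self.embedding_norm(inp)  # normalization now happens inside the block
        query = self.fc_q(inp)  # (b, n, emb)
        key = self.fc_k(inp)  # (b, n, emb)
        value = self.fc_v(inp)  # (b, n, emb)
        scores = torch.matmul(query, key.transpose(1, 2)) / (query.shape[-1] ** 0.5)
        scores = scores - 1e9 * mask.unsqueeze(1)  # keep padded entities out of the softmax
        att = torch.softmax(scores, dim=-1)
        return inp + self.fc_out(torch.matmul(att, value))  # residual connection

Keeping the norm inside the block means any network that reuses ResidualSelfAttention gets the same entity normalization without having to remember to apply it before calling the block.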

ml-agents/mlagents/trainers/torch/networks.py (4 changes)


from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
- from mlagents.trainers.torch.layers import LSTM, LinearEncoder, LayerNorm
+ from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil

if entity_max > 0:
entity_num_max += entity_max
if len(self.var_processors) > 0:
- self.embedding_norm = LayerNorm()
self.rsa = ResidualSelfAttention(self.h_size, entity_num_max)
total_enc_size = sum(self.embedding_sizes) + self.h_size
n_layers = max(1, network_settings.num_layers - 2)

):
embeddings.append(var_len_processor(encoded_self, var_len_input))
qkv = torch.cat(embeddings, dim=1)
- qkv = self.embedding_norm(qkv)
attention_embedding = self.rsa(qkv, masks)
encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
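
On the networks.py side nothing replaces the removed normalization: the forward pass now concatenates the per-entity embeddings and hands them straight to the RSA block, which normalizes internally. A reduced, hypothetical sketch of that calling pattern follows; the class name VarLenObsSketch and its signature are illustrative, while ResidualSelfAttention and the concatenations are taken from the hunks above.

import torch
from mlagents.trainers.torch.attention import ResidualSelfAttention


class VarLenObsSketch(torch.nn.Module):
    # Illustrative reduction of the variable-length observation path in networks.py.
    def __init__(self, h_size: int, entity_num_max: int) -> None:
        super().__init__()
        # No separate self.embedding_norm here any more; it moved into the RSA block.
        self.rsa = ResidualSelfAttention(h_size, entity_num_max)

    def forward(self, encoded_self, embeddings, masks):
        # embeddings: list of per-entity-type tensors produced by the var-len processors
        qkv = torch.cat(embeddings, dim=1)
        # Previously: qkv = self.embedding_norm(qkv) -- that line is removed by this commit.
        attention_embedding = self.rsa(qkv, masks)
        return torch.cat([encoded_self, attention_embedding], dim=1)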
