
moved embedding layer norm into network body

Branch: /develop/singular-embeddings
Andrew Cohen, 4 years ago
Commit: 4e37974c
2 changed files with 4 additions and 4 deletions
  1. ml-agents/mlagents/trainers/torch/attention.py (4 changed lines)
  2. ml-agents/mlagents/trainers/torch/networks.py (4 changed lines)

ml-agents/mlagents/trainers/torch/attention.py (4 changed lines)

             kernel_init=Initialization.Normal,
             kernel_gain=(0.125 / embedding_size) ** 0.5,
         )
-        self.embedding_norm = LayerNorm()

     def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
         if self.concat_self:

             expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
             # Concatenate all observations with self
             entities = torch.cat([expanded_self, entities], dim=2)
-        # Encode and entities
+        # Encode entities
-        encoded_entities = self.embedding_norm(encoded_entities)
         return encoded_entities
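
For readers without the full file, here is a minimal standalone sketch of what the encoding path does after this change, using a plain torch.nn.Linear as a stand-in for the ml-agents LinearEncoder (the shapes and the concat_self branch are assumptions taken from the hunk above, not the actual implementation); the key point is that EntityEmbedding now returns unnormalized embeddings:

    import torch

    batch, num_entities, self_size, entity_size, embedding_size = 2, 4, 6, 5, 8
    x_self = torch.randn(batch, self_size)
    entities = torch.randn(batch, num_entities, entity_size)
    # Stand-in for the LinearEncoder built in __init__ (an assumption, not the real class).
    encoder = torch.nn.Linear(self_size + entity_size, embedding_size)

    # concat_self path: broadcast the self observation across entities and concatenate.
    expanded_self = x_self.reshape(batch, 1, self_size)
    expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
    entities = torch.cat([expanded_self, entities], dim=2)

    # Encode entities; the LayerNorm call that used to follow here has been removed,
    # so normalization now happens later, in the network body (see networks.py below).
    encoded_entities = encoder(entities)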

ml-agents/mlagents/trainers/torch/networks.py (4 changed lines)

 from mlagents.trainers.settings import NetworkSettings
 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.torch.decoders import ValueHeads
-from mlagents.trainers.torch.layers import LSTM, LinearEncoder
+from mlagents.trainers.torch.layers import LSTM, LinearEncoder, LayerNorm
 from mlagents.trainers.torch.encoders import VectorInput
 from mlagents.trainers.buffer import AgentBuffer
 from mlagents.trainers.trajectory import ObsUtil

         )
         if len(self.var_processors) > 0:
+            self.embedding_norm = LayerNorm()
             self.rsa = ResidualSelfAttention(self.h_size, entity_num_max)
             total_enc_size = sum(self.embedding_sizes) + self.h_size
             n_layers = max(1, network_settings.num_layers - 2)

             ):
                 embeddings.append(var_len_processor(encoded_self, var_len_input))
             qkv = torch.cat(embeddings, dim=1)
+            qkv = self.embedding_norm(qkv)
             attention_embedding = self.rsa(qkv, masks)
             encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
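
To make the new placement concrete, here is a hedged sketch of the corresponding step in the network body, again in plain PyTorch rather than the ml-agents classes (the parameter-free layer_norm function and the skipped attention step are stand-ins for what the diff calls LayerNorm() and self.rsa; shapes are assumptions):

    import torch

    def layer_norm(x: torch.Tensor) -> torch.Tensor:
        # Parameter-free layer norm over the last dimension, an assumption about
        # how the LayerNorm() used above behaves, not the real class.
        mean = x.mean(dim=-1, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        return (x - mean) / torch.sqrt(var + 1e-5)

    batch, h_size = 2, 8
    # Embeddings produced by the per-entity-type variable-length processors.
    emb_a = torch.randn(batch, 3, h_size)
    emb_b = torch.randn(batch, 5, h_size)

    # Network body: concatenate along the entity axis, normalize once,
    # then feed the result to residual self-attention (omitted in this sketch).
    qkv = torch.cat([emb_a, emb_b], dim=1)   # (batch, 8, h_size)
    qkv = layer_norm(qkv)
    # attention_embedding = self.rsa(qkv, masks)  # as in the hunk above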
