
refactor entityembedding/network body (#4857)

/develop/singular-embeddings
GitHub, 4 years ago
Current commit
01e0ee00
4 changed files with 81 additions and 86 deletions
  1. ml-agents/mlagents/trainers/tests/torch/test_attention.py (10 changes)
  2. ml-agents/mlagents/trainers/torch/attention.py (105 changes)
  3. ml-agents/mlagents/trainers/torch/networks.py (38 changes)
  4. ml-agents/mlagents/trainers/torch/utils.py (14 changes)

ml-agents/mlagents/trainers/tests/torch/test_attention.py (10 changes)


from mlagents.trainers.torch.layers import linear_layer
from mlagents.trainers.torch.attention import (
MultiHeadAttention,
EntityEmbeddings,
EntityEmbedding,
ResidualSelfAttention,
)

input_1 = generate_input_helper(masking_pattern_1)
input_2 = generate_input_helper(masking_pattern_2)
masks = EntityEmbeddings.get_masks([input_1, input_2])
masks = ResidualSelfAttention.get_masks([input_1, input_2])
assert len(masks) == 2
masks_1 = masks[0]
masks_2 = masks[1]

torch.manual_seed(1336)
size, n_k, = 3, 5
embedding_size = 64
entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
entity_embeddings = EntityEmbedding(size, size, n_k, embedding_size)
transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(

target = torch.stack(target, dim=0)
target = target.detach()
embeddings = entity_embeddings(center, [key])
masks = EntityEmbeddings.get_masks([key])
embeddings = entity_embeddings(center, key)
masks = ResidualSelfAttention.get_masks([key])
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))
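
The test above now builds a single EntityEmbedding per entity tensor and fetches padding masks from ResidualSelfAttention instead of the removed EntityEmbeddings class. Below is a minimal end-to-end sketch of that wiring, assuming the constructor argument order (x_self_size, entity_size, entity_num_max_elements, embedding_size) implied by the calls in this diff; the tensor names and sizes are illustrative, not part of the test:

import torch
from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention

batch_size, n_k, size, embedding_size = 4, 5, 3, 64
center = torch.randn(batch_size, size)       # the "self" observation
key = torch.randn(batch_size, n_k, size)     # one variable-length entity tensor

entity_embedding = EntityEmbedding(size, size, n_k, embedding_size)
rsa = ResidualSelfAttention(embedding_size, n_k)

embeddings = entity_embedding(center, key)         # (batch_size, n_k, embedding_size)
masks = ResidualSelfAttention.get_masks([key])     # 1.0 marks all-zero (padded) entity slots
prediction = rsa(embeddings, masks)                # (batch_size, embedding_size)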

ml-agents/mlagents/trainers/torch/attention.py (105 changes)


return value_attention, att
class EntityEmbeddings(torch.nn.Module):
class EntityEmbedding(torch.nn.Module):
entity_sizes: List[int],
entity_num_max_elements: Optional[List[int]],
entity_size: int,
entity_num_max_elements: Optional[int],
embedding_size: int,
concat_self: bool = True,
):

:param entity_sizes: List of sizes for other entities. Should be of length
equivalent to the number of entities.
:param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
:param entity_size: Size of other entity.
:param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
:param embedding_size: Embedding size for entity encoders.
:param embedding_size: Embedding size for the entity encoder.
self.entity_sizes: List[int] = entity_sizes
self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
self.entity_size: int = entity_size
self.entity_num_max_elements: int = -1
if entity_num_max_elements is not None:
self.entity_num_max_elements = entity_num_max_elements

self.self_size = 0
# Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
self.ent_encoders = torch.nn.ModuleList(
[
LinearEncoder(
self.self_size + ent_size,
1,
embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
for ent_size in self.entity_sizes
]
self.ent_encoder = LinearEncoder(
self.self_size + self.entity_size,
1,
embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
self.embedding_norm = LayerNorm()
def forward(
self, x_self: torch.Tensor, entities: List[torch.Tensor]
) -> torch.Tensor:
def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
num_entities = self.entity_num_max_elements
if num_entities < 0:
if exporting_to_onnx.is_exporting():
raise UnityTrainerException(
"Trying to export an attention mechanism that doesn't have a set max \
number of elements."
)
num_entities = entities.shape[1]
expanded_self = x_self.reshape(-1, 1, self.self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
self_and_ent: List[torch.Tensor] = []
for num_entities, ent in zip(self.entity_num_max_elements, entities):
if num_entities < 0:
if exporting_to_onnx.is_exporting():
raise UnityTrainerException(
"Trying to export an attention mechanism that doesn't have a set max \
number of elements."
)
num_entities = ent.shape[1]
expanded_self = x_self.reshape(-1, 1, self.self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
else:
self_and_ent = entities
# Encode and concatenate entities
encoded_entities = torch.cat(
[ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
dim=1,
)
encoded_entities = self.embedding_norm(encoded_entities)
entities = torch.cat([expanded_self, entities], dim=2)
# Encode entities
encoded_entities = self.ent_encoder(entities)
@staticmethod
def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Takes a List of Tensors and returns a List of mask Tensors with 1 if the input was
all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
layer to mask the padding observations.
"""
with torch.no_grad():
# Generate the masking tensors for each entities tensor (mask only if all zeros)
key_masks: List[torch.Tensor] = [
(torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
for ent in observations
]
return key_masks
class ResidualSelfAttention(torch.nn.Module):
"""

kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
self.embedding_norm = LayerNorm()
inp = self.embedding_norm(inp)
# Feed to self attention
query = self.fc_q(inp) # (b, n_q, emb)
key = self.fc_k(inp) # (b, n_k, emb)

denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
output = numerator / denominator
return output
@staticmethod
def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Takes a List of Tensors and returns a List of mask Tensors with 1 if the input was
all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
layer to mask the padding observations.
"""
with torch.no_grad():
# Generate the masking tensors for each entities tensor (mask only if all zeros)
key_masks: List[torch.Tensor] = [
(torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
for ent in observations
]
return key_masks
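
get_masks (now a static method on ResidualSelfAttention) flags entity slots whose features are all zero, which is how padded entities are excluded from attention. A small worked example of the thresholded sum-of-squares rule shown above (the tensors here are illustrative):

import torch
from mlagents.trainers.torch.attention import ResidualSelfAttention

ent = torch.ones(2, 3, 4)    # (batch=2, entities=3, entity_size=4)
ent[0, 2] = 0.0              # zero-pad the third entity of the first sample
masks = ResidualSelfAttention.get_masks([ent])
print(masks[0])
# tensor([[0., 0., 1.],
#         [0., 0., 0.]])     # 1 marks the padded (all-zero) entity slot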

ml-agents/mlagents/trainers/torch/networks.py (38 changes)


from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.attention import EntityEmbeddings, ResidualSelfAttention
from mlagents.trainers.torch.attention import ResidualSelfAttention
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

else 0
)
self.processors, self.embedding_sizes, var_len_indices = ModelUtils.create_input_processors(
self.processors, self.var_processors, self.embedding_sizes = ModelUtils.create_input_processors(
observation_specs,
self.h_size,
network_settings.vis_encode_type,

if len(var_len_indices) > 0:
# There are some variable length observations and they need to be processed separately
x_self_len = sum(self.embedding_sizes) # The size of the "self" embedding
entities_sizes = [
observation_specs[idx].shape[1] for idx in var_len_indices
]
entities_max_len = [
observation_specs[idx].shape[0] for idx in var_len_indices
]
self.entities_embeddings = EntityEmbeddings(
self.h_size, entities_sizes, entities_max_len, self.h_size, False
)
self.rsa = ResidualSelfAttention(self.h_size, sum(entities_max_len))
total_enc_size = x_self_len + self.h_size
entity_num_max: int = 0
for var_processor in self.var_processors:
entity_max: int = var_processor.entity_num_max_elements
# Only adds entity max if it was known at construction
if entity_max > 0:
entity_num_max += entity_max
if len(self.var_processors) > 0:
self.rsa = ResidualSelfAttention(self.h_size, entity_num_max)
total_enc_size = sum(self.embedding_sizes) + self.h_size
n_layers = max(1, network_settings.num_layers - 2)
else:
total_enc_size = sum(self.embedding_sizes)

encoded_self = torch.cat(encodes, dim=1)
if len(var_len_inputs) > 0:
# Some inputs need to be processed with a variable length encoder
masks = EntityEmbeddings.get_masks(var_len_inputs)
qkv = self.entities_embeddings(encoded_self, var_len_inputs)
masks = ResidualSelfAttention.get_masks(var_len_inputs)
embeddings: List[torch.Tensor] = []
for var_len_input, var_len_processor in zip(
var_len_inputs, self.var_processors
):
embeddings.append(var_len_processor(encoded_self, var_len_input))
qkv = torch.cat(embeddings, dim=1)
attention_embedding = self.rsa(qkv, masks)
encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
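
In the refactored network body, each variable-length observation is embedded by its own EntityEmbedding processor, the per-processor outputs are concatenated along the entity axis, and a single ResidualSelfAttention pools them into one extra feature vector. A condensed sketch of that forward path; the stand-in shapes and the two hypothetical processors below are assumptions chosen to match the constructor calls visible in this diff:

import torch
from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention

h_size, batch = 128, 4
encoded_self = torch.randn(batch, h_size)                       # output of the fixed-size encoders
var_len_inputs = [torch.randn(batch, 6, 8), torch.randn(batch, 3, 5)]

var_processors = torch.nn.ModuleList(
    [EntityEmbedding(h_size, 8, 6, h_size), EntityEmbedding(h_size, 5, 3, h_size)]
)
rsa = ResidualSelfAttention(h_size, 6 + 3)                      # entity_num_max summed over processors

masks = ResidualSelfAttention.get_masks(var_len_inputs)
embeddings = [proc(encoded_self, x) for proc, x in zip(var_processors, var_len_inputs)]
qkv = torch.cat(embeddings, dim=1)                              # (batch, 6 + 3, h_size)
attention_embedding = rsa(qkv, masks)                           # (batch, h_size)
encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)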

ml-agents/mlagents/trainers/torch/utils.py (14 changes)


VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.torch.attention import EntityEmbedding
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.base_env import ObservationSpec, DimensionProperty

h_size: int,
vis_encode_type: EncoderType,
normalize: bool = False,
) -> Tuple[nn.ModuleList, List[int], List[int]]:
) -> Tuple[nn.ModuleList, nn.ModuleList, List[int]]:
"""
Creates visual and vector encoders, along with their normalizers.
:param observation_specs: List of ObservationSpec that represent the observation dimensions.

- A list of the inputs that need to be processed by a variable length observation encoder.
"""
encoders: List[nn.Module] = []
var_encoders: List[nn.Module] = []
embedding_sizes: List[int] = []
var_len_indices: List[int] = []
for idx, obs_spec in enumerate(observation_specs):

embedding_sizes.append(embedding_size)
if encoder is None:
var_len_indices.append(idx)
return (nn.ModuleList(encoders), embedding_sizes, var_len_indices)
x_self_size = sum(embedding_sizes) # The size of the "self" embedding
for idx in var_len_indices:
var_encoders.append(
EntityEmbedding(
x_self_size, observation_specs[idx].shape[1], observation_specs[idx].shape[0], h_size
)
)
return (nn.ModuleList(encoders), nn.ModuleList(var_encoders), embedding_sizes)
@staticmethod
def list_to_tensor(
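
create_input_processors now returns two ModuleLists, one for the fixed-size encoders and one for the EntityEmbedding processors, instead of returning the indices of variable-length observations. A minimal sketch of the new tail of that function, with hypothetical (max_entities, entity_size) shapes standing in for the real ObservationSpecs:

import torch.nn as nn
from mlagents.trainers.torch.attention import EntityEmbedding

h_size = 128
embedding_sizes = [64, 64]            # one entry per fixed-size encoder
var_len_shapes = [(20, 6), (10, 4)]   # placeholder (max_entities, entity_size) pairs

x_self_size = sum(embedding_sizes)    # the size of the "self" embedding
var_encoders = nn.ModuleList(
    [EntityEmbedding(x_self_size, shape[1], shape[0], h_size) for shape in var_len_shapes]
)
# Callers (see networks.py above) unpack: processors, var_processors, embedding_sizes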
