
use singular entity embedding (#4873)

/develop/centralizedcritic/counterfact
GitHub · 4 years ago
Current commit
89b6c949
4 files changed, with 119 insertions and 103 deletions
  1. ml-agents/mlagents/trainers/tests/torch/test_attention.py (26 changes)
  2. ml-agents/mlagents/trainers/torch/attention.py (156 changes)
  3. ml-agents/mlagents/trainers/torch/layers.py (14 changes)
  4. ml-agents/mlagents/trainers/torch/networks.py (26 changes)
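This commit replaces the list-based EntityEmbeddings module with a singular EntityEmbedding (one instance per entity type) and moves get_masks from EntityEmbeddings onto ResidualSelfAttention. The sketch below pieces the new call pattern together from the call sites changed in this commit; it is illustrative only: the sizes (h_size, obs_size, q_size, max_obs, max_q, batch) are invented, and the constructor argument order is inferred from those call sites rather than copied from the module's documentation.

import torch
from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention

h_size, obs_size, q_size = 64, 10, 12   # invented sizes
max_obs, max_q, batch = 5, 5, 8

# One embedding module per entity type (previously a single EntityEmbeddings over a list).
obs_encoder = EntityEmbedding(0, obs_size, max_obs, h_size, concat_self=False)
obs_action_encoder = EntityEmbedding(0, q_size, max_q, h_size, concat_self=False)
self_attn = ResidualSelfAttention(h_size)

obs = torch.rand((batch, max_obs, obs_size))
obs_action = torch.rand((batch, max_q, q_size))

# Embed each entity type separately, concatenate along the entity axis, then attend.
encoded = torch.cat(
    [obs_encoder(None, obs), obs_action_encoder(None, obs_action)], dim=1
)
masks = ResidualSelfAttention.get_masks([obs, obs_action])  # was EntityEmbeddings.get_masks
encoded_state = self_attn(encoded, masks)  # encoding used downstream by the critic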

ml-agents/mlagents/trainers/tests/torch/test_attention.py (26 changes)


  from mlagents.trainers.torch.layers import linear_layer
  from mlagents.trainers.torch.attention import (
      MultiHeadAttention,
-     EntityEmbeddings,
+     EntityEmbedding,
      ResidualSelfAttention,
  )

  input_1 = generate_input_helper(masking_pattern_1)
  input_2 = generate_input_helper(masking_pattern_2)
- masks = EntityEmbeddings.get_masks([input_1, input_2])
+ masks = ResidualSelfAttention.get_masks([input_1, input_2])
  assert len(masks) == 2
  masks_1 = masks[0]
  masks_2 = masks[1]

  torch.manual_seed(1336)
  size, n_k, = 3, 5
  embedding_size = 64
- entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
- transformer = ResidualSelfAttention(embedding_size, [n_k])
+ entity_embeddings = EntityEmbedding(size, size, n_k, embedding_size)
+ transformer = ResidualSelfAttention(embedding_size, n_k)
  point_range = 3
  init_error = -1.0
- for _ in range(250):
-     center = torch.rand((batch_size, size)) * point_range * 2 - point_range
-     key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
+ for _ in range(200):
+     center = torch.rand((batch_size, size))
+     key = torch.rand((batch_size, n_k, size))
      with torch.no_grad():
          # create the target : The key closest to the query in euclidean distance
          distance = torch.sum(

      target = torch.stack(target, dim=0)
      target = target.detach()
-     embeddings = entity_embeddings(center, [key])
-     masks = EntityEmbeddings.get_masks([key])
+     embeddings = entity_embeddings(center, key)
+     masks = ResidualSelfAttention.get_masks([key])
      if init_error == -1.0:
          init_error = error.item()
      else:
          assert error.item() < init_error
- assert error.item() < 0.3
+ assert error.item() < 0.02
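The hunk above elides the body that builds the training target (the key closest to the query). Below is a self-contained sketch of that computation for the shapes the test uses; it illustrates the idea rather than reproducing the elided lines.

import torch

batch_size, n_k, size = 4, 5, 3
center = torch.rand((batch_size, size))
key = torch.rand((batch_size, n_k, size))

with torch.no_grad():
    # Squared Euclidean distance from the query to each of the n_k keys.
    distance = torch.sum((center.reshape(batch_size, 1, size) - key) ** 2, dim=2)
    argmin = torch.argmin(distance, dim=1)  # index of the closest key per batch element
    target = torch.stack([key[i, argmin[i], :] for i in range(batch_size)], dim=0)
    target = target.detach()  # (batch_size, size), compared against the model's prediction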

ml-agents/mlagents/trainers/torch/attention.py (156 changes)


  from mlagents.torch_utils import torch
  from typing import Tuple, Optional, List
- from mlagents.trainers.torch.layers import LinearEncoder, Initialization, linear_layer
+ from mlagents.trainers.torch.layers import (
+     LinearEncoder,
+     Initialization,
+     linear_layer,
+     LayerNorm,
+ )
"""
Multi Head Attention module. We do not use the regular Torch implementation since
Barracuda does not support some operators it uses.
Takes as input to the forward method 3 tensors:
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
The forward method will return 2 tensors:
- The output: (batch_size, number_of_queries, embedding_size)
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
"""
"""
Multi Head Attention module. We do not use the regular Torch implementation since
Barracuda does not support some operators it uses.
Takes as input to the forward method 3 tensors:
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
The forward method will return 2 tensors:
- The output: (batch_size, number_of_queries, embedding_size)
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
:param embedding_size: The size of the embeddings that will be generated (should be
dividable by the num_heads)
:param total_max_elements: The maximum total number of entities that can be passed to
the module
:param num_heads: The number of heads of the attention module
"""
super().__init__()
self.n_heads = num_heads
self.head_size: int = embedding_size // self.n_heads

return value_attention, att
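As a reading aid for the shapes listed in the docstring above, here is a minimal plain-torch sketch of the computation it describes (scaled dot-product attention across heads with a key padding mask). It deliberately uses only standard torch ops rather than the module itself, and all sizes are invented.

import torch

batch, n_heads, n_q, n_k, emb = 2, 4, 3, 5, 64
head_size = emb // n_heads

query = torch.rand((batch, n_q, n_heads, head_size)).permute(0, 2, 1, 3)
key = torch.rand((batch, n_k, n_heads, head_size)).permute(0, 2, 1, 3)
value = torch.rand((batch, n_k, n_heads, head_size)).permute(0, 2, 1, 3)
key_mask = torch.zeros((batch, 1, 1, n_k))  # 1 where a key is padding

scores = torch.matmul(query, key.transpose(-1, -2)) / head_size ** 0.5
scores = scores.masked_fill(key_mask.bool(), float("-inf"))
att = torch.softmax(scores, dim=-1)                      # (batch, n_heads, n_q, n_k)
out = torch.matmul(att, value)                           # (batch, n_heads, n_q, head_size)
out = out.permute(0, 2, 1, 3).reshape(batch, n_q, emb)   # (batch, n_q, emb)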
- class EntityEmbeddings(torch.nn.Module):
+ class EntityEmbedding(torch.nn.Module):
      """
      A module used to embed entities before passing them to a self-attention block.
      Used in conjunction with ResidualSelfAttention to encode information about a self
      and additional entities. Can also concatenate self to entities for ego-centric self-
      attention. Inspired by architecture used in https://arxiv.org/pdf/1909.07528.pdf.
      """
-         entity_sizes: List[int],
-         entity_num_max_elements: Optional[List[int]] = None,
+         entity_size: int,
+         entity_num_max_elements: Optional[int],
-         :param entity_sizes: List of sizes for other entities. Should be of length
-         equivalent to the number of entities.
-         :param embedding_size: Embedding size for entity encoders.
-         :param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
-         :param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric
+         :param entity_size: Size of other entitiy.
+         :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
+         :param embedding_size: Embedding size for the entity encoder.
+         :param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric
-         self.entity_sizes: List[int] = entity_sizes
-         self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
+         self.entity_size: int = entity_size
+         self.entity_num_max_elements: int = -1
          if entity_num_max_elements is not None:
              self.entity_num_max_elements = entity_num_max_elements

          self.self_size = 0
-         self.ent_encoders = torch.nn.ModuleList(
-             [
-                 LinearEncoder(
-                     self.self_size + ent_size,
-                     1,
-                     embedding_size,
-                     kernel_init=Initialization.Normal,
-                     kernel_gain=(0.125 / embedding_size) ** 0.5,
-                 )
-                 for ent_size in self.entity_sizes
-             ]
-         )
+         # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
+         self.ent_encoder = LinearEncoder(
+             self.self_size + self.entity_size,
+             1,
+             embedding_size,
+             kernel_init=Initialization.Normal,
+             kernel_gain=(0.125 / embedding_size) ** 0.5,
+         )
-     def forward(
-         self, x_self: torch.Tensor, entities: List[torch.Tensor]
-     ) -> Tuple[torch.Tensor, int]:
-         self_and_ent: List[torch.Tensor] = []
-         for num_entities, ent in zip(self.entity_num_max_elements, entities):
-             if num_entities < 0:
-                 if exporting_to_onnx.is_exporting():
-                     raise UnityTrainerException(
-                         "Trying to export an attention mechanism that doesn't have a set max \
-                         number of elements."
-                     )
-                 num_entities = ent.shape[1]
-             expanded_self = x_self.reshape(-1, 1, self.self_size)
-             expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
-             self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
-         else:
-             self_and_ent = entities
-         # Encode and concatenate entites
-         encoded_entities = torch.cat(
-             [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
-             dim=1,
-         )
+     def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
+         num_entities = self.entity_num_max_elements
+         if num_entities < 0:
+             if exporting_to_onnx.is_exporting():
+                 raise UnityTrainerException(
+                     "Trying to export an attention mechanism that doesn't have a set max \
+                     number of elements."
+                 )
+             num_entities = entities.shape[1]
+         expanded_self = x_self.reshape(-1, 1, self.self_size)
+         expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
+         entities = torch.cat([expanded_self, entities], dim=2)
+         # Encode entities
+         encoded_entities = self.ent_encoder(entities)
-     @staticmethod
-     def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
-         """
-         Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
-         all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
-         layer to mask the padding observations.
-         """
-         with torch.no_grad():
-             # Generate the masking tensors for each entities tensor (mask only if all zeros)
-             key_masks: List[torch.Tensor] = [
-                 (torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
-                 for ent in observations
-             ]
-         return key_masks
  class ResidualSelfAttention(torch.nn.Module):
      """

      def __init__(
          self,
          embedding_size: int,
-         entity_num_max_elements: Optional[List[int]] = None,
+         entity_num_max_elements: Optional[int] = None,
          num_heads: int = 4,
      ):
          """

          super().__init__()
          self.max_num_ent: Optional[int] = None
          if entity_num_max_elements is not None:
-             _entity_num_max_elements = entity_num_max_elements
-             self.max_num_ent = sum(_entity_num_max_elements)
+             self.max_num_ent = entity_num_max_elements
          # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
          self.fc_q = linear_layer(
              embedding_size,
              embedding_size,
              kernel_init=Initialization.Normal,
              kernel_gain=(0.125 / embedding_size) ** 0.5,
          )
+         self.embedding_norm = LayerNorm()
+         self.residual_norm = LayerNorm()

+         inp = self.embedding_norm(inp)
          # Feed to self attention
          query = self.fc_q(inp)  # (b, n_q, emb)
          key = self.fc_k(inp)  # (b, n_k, emb)

          output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
          # Residual
          output = self.fc_out(output) + inp
+         output = self.residual_norm(output)
          # Residual between x_self and the output of the module
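The two LayerNorm instances added above wrap the attention block: the entity embeddings are normalized before the Q/K/V projections, and the projected output plus the residual is normalized again. The schematic below shows that ordering with plain tensors; the single proj layer is a stand-in sharing one weight matrix for brevity, not the module's own fc_q / fc_k / fc_v / fc_out.

import torch

batch, n_ent, emb = 2, 5, 64
inp = torch.rand((batch, n_ent, emb))  # encoded entities entering forward()

def layer_norm(x: torch.Tensor) -> torch.Tensor:
    mean = torch.mean(x, dim=-1, keepdim=True)
    var = torch.mean((x - mean) ** 2, dim=-1, keepdim=True)
    return (x - mean) / torch.sqrt(var + 1e-5)

proj = torch.nn.Linear(emb, emb)  # stand-in for the four projection layers

inp = layer_norm(inp)                                              # embedding_norm
q, k, v = proj(inp), proj(inp), proj(inp)
att = torch.softmax(q @ k.transpose(-1, -2) / emb ** 0.5, dim=-1)
out = proj(att @ v) + inp                                          # residual to the normalized input
out = layer_norm(out)                                              # residual_norm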
+     @staticmethod
+     def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
+         """
+         Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
+         all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
+         layer to mask the padding observations.
+         """
+         with torch.no_grad():
+             # Generate the masking tensors for each entities tensor (mask only if all zeros)
+             key_masks: List[torch.Tensor] = [
+                 (torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
+                 for ent in observations
+             ]
+         return key_masks
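The mask convention of get_masks is the one its docstring states: 1 marks an entity row that is entirely zero (padding), 0 marks a real entity. A tiny, self-contained illustration of the same expression; the tensor contents are invented.

import torch

# Batch of 2 samples, up to 3 entities of size 4 each, zero-padded.
obs = torch.zeros((2, 3, 4))
obs[0, 0] = 1.0   # first sample has one real entity
obs[1, :2] = 1.0  # second sample has two real entities

mask = (torch.sum(obs ** 2, dim=2) < 0.01).type(torch.FloatTensor)
print(mask)
# tensor([[0., 1., 1.],
#         [0., 0., 1.]])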

ml-agents/mlagents/trainers/torch/layers.py (14 changes)


      pass

+ class LayerNorm(torch.nn.Module):
+     """
+     A vanilla implementation of layer normalization https://arxiv.org/pdf/1607.06450.pdf
+     norm_x = (x - mean) / sqrt(mean((x - mean) ^ 2) + epsilon)
+     This does not include the trainable parameters gamma and beta for performance speed.
+     Typically, this is norm_x * gamma + beta
+     """
+
+     def forward(self, layer_activations: torch.Tensor) -> torch.Tensor:
+         mean = torch.mean(layer_activations, dim=-1, keepdim=True)
+         var = torch.mean((layer_activations - mean) ** 2, dim=-1, keepdim=True)
+         return (layer_activations - mean) / (torch.sqrt(var + 1e-5))
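Since the class carries no learnable gamma/beta, its output should match torch's functional layer_norm called without affine parameters; a quick self-contained check of that equivalence:

import torch

x = torch.rand((2, 5, 8))
mean = torch.mean(x, dim=-1, keepdim=True)
var = torch.mean((x - mean) ** 2, dim=-1, keepdim=True)
ours = (x - mean) / torch.sqrt(var + 1e-5)

reference = torch.nn.functional.layer_norm(x, normalized_shape=(8,), eps=1e-5)
print(torch.allclose(ours, reference, atol=1e-5))  # expected: True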
  class LinearEncoder(torch.nn.Module):
      """
      Linear layers.

ml-agents/mlagents/trainers/torch/networks.py (26 changes)


  from mlagents.trainers.torch.encoders import VectorInput
  from mlagents.trainers.buffer import AgentBuffer
  from mlagents.trainers.trajectory import ObsUtil
- from mlagents.trainers.torch.attention import ResidualSelfAttention, EntityEmbeddings
+ from mlagents.trainers.torch.attention import ResidualSelfAttention, EntityEmbedding
  ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

              + sum(self.action_spec.discrete_branches)
              + self.action_spec.continuous_size
          )
-         self.entity_encoder = EntityEmbeddings(
-             0, [obs_only_ent_size, q_ent_size], self.h_size, concat_self=False
-         )
+         self.obs_encoder = EntityEmbedding(
+             0, obs_only_ent_size, None, self.h_size, concat_self=False
+         )
+         self.obs_action_encoder = EntityEmbedding(
+             0, q_ent_size, None, self.h_size, concat_self=False
+         )
          self.self_attn = ResidualSelfAttention(self.h_size)
          encoder_input_size = self.h_size
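The hunk above shows the Q entity size including the flattened action dimensions on top of the observation size, which is why the observation entities and the observation-action entities are now embedded by two separate EntityEmbedding modules. Purely illustrative arithmetic with invented sizes:

# Invented sizes, for illustration only.
obs_only_ent_size = 10          # flattened observations per agent
continuous_size = 2
discrete_branches = (3, 2)

q_ent_size = obs_only_ent_size + sum(discrete_branches) + continuous_size
print(q_ent_size)  # 17 -- a different entity width, hence a second EntityEmbedding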

          value_masks = self._get_masks_from_nans(value_inputs)
          q_masks = self._get_masks_from_nans(q_inputs)
-         encoded_entity = self.entity_encoder(None, [value_input_concat, q_input_concat])
+         encoded_obs = self.obs_encoder(None, value_input_concat)
+         encoded_obs_action = self.obs_action_encoder(None, q_input_concat)
+         encoded_entity = torch.cat([encoded_obs, encoded_obs_action], dim=1)
          encoded_state = self.self_attn(encoded_entity, [value_masks, q_masks])
          if len(concat_encoded_obs) == 0:

              all_net_inputs, [], [], memories=critic_mem, sequence_length=sequence_length
          )
          value_outputs, critic_mem_out = self.critic(
-             critic_obs, [inputs], [actions], memories=critic_mem, sequence_length=sequence_length
+             critic_obs,
+             [inputs],
+             [actions],
+             memories=critic_mem,
+             sequence_length=sequence_length,
          )
          if mar_value_outputs is None:
              mar_value_outputs = value_outputs
          )
          value_outputs, critic_mem_outs = self.critic(
-             [inputs], critic_obs, actions, memories=critic_mem, sequence_length=sequence_length
+             [inputs],
+             critic_obs,
+             actions,
+             memories=critic_mem,
+             sequence_length=sequence_length,
          )
          return log_probs, entropies, value_outputs, mar_value_outputs
