
refactor entityembedding/network body

/develop/singular-embeddings
Andrew Cohen, 4 years ago
Commit ad807327
3 changed files, with 73 additions and 75 deletions
  1. ml-agents/mlagents/trainers/torch/attention.py (100 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (28 changes)
  3. ml-agents/mlagents/trainers/torch/utils.py (20 changes)

ml-agents/mlagents/trainers/torch/attention.py (100 changes)


     return value_attention, att


-class EntityEmbeddings(torch.nn.Module):
+class EntityEmbedding(torch.nn.Module):
     def __init__(
         self,
         x_self_size: int,
-        entity_sizes: List[int],
-        entity_num_max_elements: Optional[List[int]],
+        entity_size: int,
+        entity_num_max_elements: Optional[int],
         embedding_size: int,
         concat_self: bool = True,
     ):
         """
-        :param entity_sizes: List of sizes for other entities. Should be of length
-        equivalent to the number of entities.
-        :param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
-        :param embedding_size: Embedding size for entity encoders.
+        :param entity_size: Size of other entity.
+        :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
+        :param embedding_size: Embedding size for the entity encoder.
         """
-        self.entity_sizes: List[int] = entity_sizes
-        self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
+        self.entity_size: int = entity_size
+        self.entity_num_max_elements: int = -1
         if entity_num_max_elements is not None:
             self.entity_num_max_elements = entity_num_max_elements

         self.self_size = 0
         # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
-        self.ent_encoders = torch.nn.ModuleList(
-            [
-                LinearEncoder(
-                    self.self_size + ent_size,
-                    1,
-                    embedding_size,
-                    kernel_init=Initialization.Normal,
-                    kernel_gain=(0.125 / embedding_size) ** 0.5,
-                )
-                for ent_size in self.entity_sizes
-            ]
-        )
+        self.ent_encoder = LinearEncoder(
+            self.self_size + self.entity_size,
+            1,
+            embedding_size,
+            kernel_init=Initialization.Normal,
+            kernel_gain=(0.125 / embedding_size) ** 0.5,
+        )

-    def forward(
-        self, x_self: torch.Tensor, entities: List[torch.Tensor]
-    ) -> torch.Tensor:
-        if self.concat_self:
-            # Concatenate all observations with self
-            self_and_ent: List[torch.Tensor] = []
-            for num_entities, ent in zip(self.entity_num_max_elements, entities):
-                if num_entities < 0:
-                    if exporting_to_onnx.is_exporting():
-                        raise UnityTrainerException(
-                            "Trying to export an attention mechanism that doesn't have a set max \
-                            number of elements."
-                        )
-                    num_entities = ent.shape[1]
-                expanded_self = x_self.reshape(-1, 1, self.self_size)
-                expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
-                self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
-        else:
-            self_and_ent = entities
-        # Encode and concatenate entities
-        encoded_entities = torch.cat(
-            [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
-            dim=1,
-        )
+    def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
+        num_entities = self.entity_num_max_elements
+        if num_entities < 0:
+            if exporting_to_onnx.is_exporting():
+                raise UnityTrainerException(
+                    "Trying to export an attention mechanism that doesn't have a set max \
+                    number of elements."
+                )
+            num_entities = entities.shape[1]
+        if self.concat_self:
+            # Concatenate all observations with self
+            expanded_self = x_self.reshape(-1, 1, self.self_size)
+            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
+            entities = torch.cat([expanded_self, entities], dim=2)
+        # Encode entities
+        encoded_entities = self.ent_encoder(entities)
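The refactor replaces the per-entity-type ModuleList with a single encoder, so one EntityEmbedding instance now handles exactly one variable-length observation. A minimal usage sketch follows; the shapes and the positional argument order are assumptions inferred from this diff and the utils.py call site below, not verified against the repository:

    import torch

    # Hypothetical sizes: 16-dim self encoding, up to 5 entities of size 8,
    # embedded into 32 dimensions. Argument order mirrors the utils.py call:
    # (x_self_size, entity_size, entity_num_max_elements, embedding_size).
    embedding = EntityEmbedding(16, 8, 5, 32)
    x_self = torch.randn(4, 16)      # (batch, x_self_size)
    entities = torch.randn(4, 5, 8)  # (batch, max_elements, entity_size)
    encoded = embedding(x_self, entities)
    # encoded: (4, 5, 32) -- one row per entity, with self features concatenated in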
-    @staticmethod
-    def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
-        """
-        Takes a List of Tensors and returns a List of mask Tensors with 1 if the input was
-        all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
-        layer to mask the padding observations.
-        """
-        with torch.no_grad():
-            # Generate the masking tensors for each entities tensor (mask only if all zeros)
-            key_masks: List[torch.Tensor] = [
-                (torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
-                for ent in observations
-            ]
-        return key_masks

 class ResidualSelfAttention(torch.nn.Module):

         denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
         output = numerator / denominator
         return output

+    @staticmethod
+    def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
+        """
+        Takes a List of Tensors and returns a List of mask Tensors with 1 if the input was
+        all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
+        layer to mask the padding observations.
+        """
+        with torch.no_grad():
+            # Generate the masking tensors for each entities tensor (mask only if all zeros)
+            key_masks: List[torch.Tensor] = [
+                (torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
+                for ent in observations
+            ]
+        return key_masks
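get_masks is unchanged apart from its new home on ResidualSelfAttention: an entity slot is flagged as padding (mask value 1) only when its feature vector is essentially all zeros, via the 0.01 threshold on the squared sum. A small illustrative sketch:

    import torch

    obs = torch.tensor([[[1.0, 2.0],     # a real entity
                         [0.0, 0.0]]])   # a zero-padded slot
    masks = ResidualSelfAttention.get_masks([obs])
    # masks[0] == tensor([[0., 1.]]) -- only the padded row is masked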

ml-agents/mlagents/trainers/torch/networks.py (28 changes)


 from mlagents.trainers.torch.encoders import VectorInput
 from mlagents.trainers.buffer import AgentBuffer
 from mlagents.trainers.trajectory import ObsUtil
-from mlagents.trainers.torch.attention import EntityEmbeddings, ResidualSelfAttention
+from mlagents.trainers.torch.attention import ResidualSelfAttention

 ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

             else 0
         )
-        self.processors, self.embedding_sizes, var_len_indices = ModelUtils.create_input_processors(
+        self.processors, self.var_processors, self.embedding_sizes, entity_num_max = ModelUtils.create_input_processors(
             observation_specs,
             self.h_size,
             network_settings.vis_encode_type,

-        if len(var_len_indices) > 0:
-            entities_sizes = [
-                observation_specs[idx].shape[1] for idx in var_len_indices
-            ]
-            entities_max_len = [
-                observation_specs[idx].shape[0] for idx in var_len_indices
-            ]
-            self.entities_embeddings = EntityEmbeddings(
-                self.h_size, entities_sizes, entities_max_len, self.h_size, False
-            )
-            self.rsa = ResidualSelfAttention(self.h_size, sum(entities_max_len))
+        if len(self.var_processors) > 0:
+            self.rsa = ResidualSelfAttention(self.h_size, entity_num_max)
             total_enc_size = x_self_len + self.h_size
         n_layers = max(1, network_settings.num_layers - 2)

         encoded_self = torch.cat(encodes, dim=1)
         if len(var_len_inputs) > 0:
             # Some inputs need to be processed with a variable length encoder
-            masks = EntityEmbeddings.get_masks(var_len_inputs)
-            qkv = self.entities_embeddings(encoded_self, var_len_inputs)
+            masks = ResidualSelfAttention.get_masks(var_len_inputs)
+            embeddings: List[torch.Tensor] = []
+            for var_len_input, var_len_processor in zip(
+                var_len_inputs, self.var_processors
+            ):
+                embeddings.append(var_len_processor(encoded_self, var_len_input))
+            qkv = torch.cat(embeddings, dim=1)
             attention_embedding = self.rsa(qkv, masks)
             encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
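End to end, the network body now builds one embedding per variable-length input and lets a single ResidualSelfAttention attend over their concatenation. A self-contained sketch of that flow; module names and shapes are illustrative, and constructor argument orders are assumed from this diff:

    import torch

    h_size, batch = 32, 4
    encoded_self = torch.randn(batch, h_size)
    # Two variable-length inputs: up to 3 entities of size 6, up to 5 of size 4.
    var_len_inputs = [torch.randn(batch, 3, 6), torch.randn(batch, 5, 4)]
    var_processors = torch.nn.ModuleList(
        [EntityEmbedding(h_size, 6, 3, h_size), EntityEmbedding(h_size, 4, 5, h_size)]
    )
    rsa = ResidualSelfAttention(h_size, 3 + 5)  # entity_num_max = sum of max elements

    masks = ResidualSelfAttention.get_masks(var_len_inputs)
    embeddings = [p(encoded_self, x) for p, x in zip(var_processors, var_len_inputs)]
    qkv = torch.cat(embeddings, dim=1)       # (batch, 8, h_size)
    attention_embedding = rsa(qkv, masks)    # (batch, h_size)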

ml-agents/mlagents/trainers/torch/utils.py (20 changes)


     VectorInput,
 )
 from mlagents.trainers.settings import EncoderType, ScheduleType
+from mlagents.trainers.torch.attention import EntityEmbedding
 from mlagents.trainers.exception import UnityTrainerException
 from mlagents_envs.base_env import ObservationSpec, DimensionProperty

         h_size: int,
         vis_encode_type: EncoderType,
         normalize: bool = False,
-    ) -> Tuple[nn.ModuleList, List[int], List[int]]:
+    ) -> Tuple[nn.ModuleList, nn.ModuleList, List[int], int]:
         """
         Creates visual and vector encoders, along with their normalizers.
         :param observation_specs: List of ObservationSpec that represent the observation dimensions.

         - A list of the inputs that need to be processed by a variable length observation encoder.
         """
         encoders: List[nn.Module] = []
+        var_encoders: List[nn.Module] = []
+        entity_num_max_elements = 0
         for idx, obs_spec in enumerate(observation_specs):
             encoder, embedding_size = ModelUtils.get_encoder_for_obs(
                 obs_spec, normalize, h_size, vis_encode_type

             if encoder is None:
                 var_len_indices.append(idx)
-        return (nn.ModuleList(encoders), embedding_sizes, var_len_indices)
+        x_self_size = sum(embedding_sizes)  # The size of the "self" embedding
+        for idx in var_len_indices:
+            entity_max: int = observation_specs[idx].shape[0]
+            var_encoders.append(
+                EntityEmbedding(
+                    x_self_size, observation_specs[idx].shape[1], entity_max, h_size
+                )
+            )
+            entity_num_max_elements += entity_max
+        return (
+            nn.ModuleList(encoders),
+            nn.ModuleList(var_encoders),
+            embedding_sizes,
+            entity_num_max_elements,
+        )

     @staticmethod
     def list_to_tensor(
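Callers of create_input_processors now unpack four values; the extra ModuleList and the summed entity capacity are what networks.py above uses to build its attention stack. A hedged sketch of the new contract (the h_size value and encoder type here are illustrative):

    processors, var_processors, embedding_sizes, entity_num_max = (
        ModelUtils.create_input_processors(
            observation_specs,              # List[ObservationSpec]
            h_size=128,
            vis_encode_type=EncoderType.SIMPLE,
            normalize=False,
        )
    )
    # processors:       fixed-size encoders (nn.ModuleList)
    # var_processors:   one EntityEmbedding per variable-length observation
    # embedding_sizes:  output size of each fixed-size encoder
    # entity_num_max:   total max entities, used to size ResidualSelfAttention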
