
Refactoring the code to make it more flexible. Still a hack

/exp-alternate-atten
vincentpierre, 4 years ago
Current commit: 6fcbba53
3 files changed, 123 insertions and 49 deletions
  1. ml-agents/mlagents/trainers/torch/layers.py (77 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (81 changes)
  3. ml-agents/mlagents/trainers/torch/utils.py (14 changes)

ml-agents/mlagents/trainers/torch/layers.py (77 changes)


from mlagents.torch_utils import torch
import abc
from typing import Tuple
from typing import Tuple, List
from enum import Enum

class MultiHeadAttention(torch.nn.Module):
    """
    Multi Head Attention module. We do not use the regular Torch implementation since
    Barracuda does not support some operators it uses.
    Takes as input to the forward method 3 tensors:
    - query: of dimensions (batch_size, number_of_queries, key_size)
    - key: of dimensions (batch_size, number_of_keys, key_size)
    - value: of dimensions (batch_size, number_of_keys, value_size)
    The forward method will return 2 tensors:
    - The output: (batch_size, number_of_queries, output_size)
    - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
    """

    NEG_INF = -1e6

    def __init__(

        out = self.fc_out(value_attention)  # (b, n_q, emb)
        return out, att
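
A minimal usage sketch of MultiHeadAttention, for reference. The constructor keywords and the trailing key-mask argument to forward are inferred from the call sites later in this diff (SimpleTransformer.forward and networks.py); the exact signature is not shown in the diff, so treat it as an assumption.

from mlagents.torch_utils import torch
from mlagents.trainers.torch.layers import MultiHeadAttention

mha = MultiHeadAttention(
    query_size=64, key_size=64, value_size=64,
    output_size=64, num_heads=4, embedding_size=64,
)
query = torch.ones(2, 5, 64)   # (batch_size, number_of_queries, key_size)
key = torch.ones(2, 7, 64)     # (batch_size, number_of_keys, key_size)
value = torch.ones(2, 7, 64)   # (batch_size, number_of_keys, value_size)
key_mask = torch.zeros(2, 7)   # assumed: 1 masks a key, 0 lets it through
out, att = mha(query, key, value, key_mask)
# Per the docstring: out is (2, 5, 64) and att is (2, 4, 5, 7)
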
class SimpleTransformer(torch.nn.Module):
    """
    A simple architecture inspired by https://arxiv.org/pdf/1909.07528.pdf that uses
    multi head self attention to encode information about a "Self" and a list of
    relevant "Entities".
    """

    EPISLON = 1e-7

    def __init__(
        self, x_self_size: int, entities_sizes: List[int], embedding_size: int
    ):
        super().__init__()
        self.self_size = x_self_size
        self.entities_sizes = entities_sizes
        self.ent_encoders = torch.nn.ModuleList(
            [
                LinearEncoder(self.self_size + ent_size, 1, embedding_size)
                for ent_size in self.entities_sizes
            ]
        )
        self.attention = MultiHeadAttention(
            query_size=embedding_size,
            key_size=embedding_size,
            value_size=embedding_size,
            output_size=embedding_size,
            num_heads=4,
            embedding_size=embedding_size,
        )
        self.residual_layer = LinearEncoder(embedding_size, 1, embedding_size)

    def forward(self, x_self: torch.Tensor, entities: List[torch.Tensor]):
        # Generate the masking tensors for each entity tensor (mask only if all zeros)
        key_masks: List[torch.Tensor] = [
            (torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
            for ent in entities
        ]
        # Concatenate all observations with self
        self_and_ent: List[torch.Tensor] = []
        for ent_size, ent in zip(self.entities_sizes, entities):
            num_entities = ent.shape[1]
            expanded_self = x_self.reshape(-1, 1, self.self_size).repeat(
                1, num_entities, 1
            )
            self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
        # Generate the tensor that will serve as query, key and value to self attention
        qkv = torch.cat(
            [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
            dim=1,
        )
        mask = torch.cat(key_masks, dim=1)
        # Feed to self attention
        output, _ = self.attention(qkv, qkv, qkv, mask)
        # Residual
        output = self.residual_layer(output) + qkv
        # Average Pooling
        max_num_ent = qkv.shape[1]
        numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1), dim=1)
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPISLON
        output = numerator / denominator
        return output
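
To make the data flow concrete, here is a small usage sketch of SimpleTransformer as defined above; the sizes are illustrative.

from mlagents.torch_utils import torch
from mlagents.trainers.torch.layers import SimpleTransformer

st = SimpleTransformer(x_self_size=6, entities_sizes=[4], embedding_size=64)
x_self = torch.ones(2, 6)          # (batch_size, x_self_size)
entities = [torch.ones(2, 20, 4)]  # one tensor per entity type: (batch_size, num_entities, ent_size)
encoded = st(x_self, entities)     # (2, 64): masked average pool over the 20 entity embeddings

# All-zero entity rows are detected by the key mask and excluded from the pooling,
# so padded entity slots do not dilute the encoding:
padded = torch.ones(2, 20, 4)
padded[:, 10:, :] = 0.0                 # last 10 entities are padding
encoded_padded = st(x_self, [padded])   # still (2, 64); only the first 10 entities contribute
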

ml-agents/mlagents/trainers/torch/networks.py (81 changes)


from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.layers import LSTM, LinearEncoder, SimpleTransformer
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

            else 0
        )
        self.visual_processors, self.vector_processors, self.attention, _ = ModelUtils.create_input_processors(
        self.visual_processors, self.vector_processors, encoder_input_size = ModelUtils.create_input_processors(
            observation_shapes,
            self.h_size,
            network_settings.vis_encode_type,

        emb_size = 64
        self.transformer = SimpleTransformer(
            x_self_size=encoder_input_size,
            entities_sizes=[4],  # hard coded, 4 obs per entity
            embedding_size=emb_size,
        )
        self.self_embedding = LinearEncoder(6, 2, 64)
        self.obs_embeding = LinearEncoder(4, 2, 64)
        self.self_and_obs_embedding = LinearEncoder(64 + 64, 1, 64)
        self.dense_after_attention = LinearEncoder(64, 1, 64)
        # self.self_embedding = LinearEncoder(6, 2, 64)
        # self.obs_embeding = LinearEncoder(4, 2, 64)
        # self.self_and_obs_embedding = LinearEncoder(64 + 64, 1, 64)
        # self.dense_after_attention = LinearEncoder(64, 1, 64)
            64 * 2, network_settings.num_layers, self.h_size
            emb_size + encoder_input_size, network_settings.num_layers, self.h_size
        )
        if self.use_lstm:

        else:
            inputs = torch.cat(encodes, dim=-1)
        # TODO : This is a Hack
        x_self = processed_vec.reshape(-1, processed_vec.shape[1])
        key_mask = (
            torch.sum(var_len_input ** 2, axis=2) < 0.01
        ).type(torch.FloatTensor)  # 1 means mask and 0 means let through
        output = self.transformer(x_self, [var_len_input])
        x_self = processed_vec.reshape(-1, processed_vec.shape[1])
        x_self = self.self_embedding(x_self)  # (b, 1,64)
        expanded_x_self = x_self.reshape(-1, 1, 64).repeat(1, 20, 1)
        # # TODO : This is a Hack
        # var_len_input = vis_inputs[0].reshape(-1, 20, 4)
        # key_mask = (
        #     torch.sum(var_len_input ** 2, axis=2) < 0.01
        # ).type(torch.FloatTensor)  # 1 means mask and 0 means let through
        obj_emb = self.obs_embeding(var_len_input)
        objects = torch.cat([expanded_x_self, obj_emb], dim=2)  # (b,20,64)
        # x_self = processed_vec.reshape(-1, processed_vec.shape[1])
        # x_self = self.self_embedding(x_self)  # (b, 1,64)
        # expanded_x_self = x_self.reshape(-1, 1, 64).repeat(1, 20, 1)
        # obj_emb = self.obs_embeding(var_len_input)
        # objects = torch.cat([expanded_x_self, obj_emb], dim=2)  # (b,20,64)
        obj_and_self = self.self_and_obs_embedding(objects)  # (b,20,64)
        # add the self to the entities
        # self_and_key_emb = torch.cat(
        #     [x_self.reshape(-1, 1, 64), obj_and_self], dim=1
        # )  # (b,21,64)
        # key_mask = torch.cat(
        #     [torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1
        # )  # first one is never masked
        # obj_and_self = self.self_and_obs_embedding(objects)  # (b,20,64)
        # # add the self to the entities
        # # self_and_key_emb = torch.cat(
        # #     [x_self.reshape(-1, 1, 64), obj_and_self], dim=1
        # # )  # (b,21,64)
        # # key_mask = torch.cat(
        # #     [torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1
        # # )  # first one is never masked
        # # output, _ = self.attention(
        # #     self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask
        # # )  # (b, 21, 64)
            # self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask
            # obj_and_self, obj_and_self, obj_and_self, key_mask
        output, _ = self.attention(
            obj_and_self, obj_and_self, obj_and_self, key_mask
        )  # (b, 21, 64)
        output = self.dense_after_attention(output) + obj_and_self
        # output = self.dense_after_attention(output) + obj_and_self
        output = torch.sum(
            output * (1 - key_mask).reshape(-1, 20, 1), dim=1
        ) / (torch.sum(
            1 - key_mask, dim=1, keepdim=True
        ) + 0.001)  # average pooling
        # output = torch.sum(
        #     output * (1 - key_mask).reshape(-1, 20, 1), dim=1
        # ) / (torch.sum(
        #     1 - key_mask, dim=1, keepdim=True
        # ) + 0.001)  # average pooling
        encoding = self.linear_encoder(torch.cat([output , x_self], dim=1))
        encoding = self.linear_encoder(torch.cat([output, x_self], dim=1))
        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
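
Taken together, the networks.py changes swap the hand-rolled embedding + attention + pooling hack for a single SimpleTransformer, while keeping the hard-coded reshape of vis_inputs[0] into 20 entities of 4 observations (the part the commit message still calls a hack). Below is a simplified, self-contained sketch of the new encoding path; encoder_input_size, num_layers, and h_size are illustrative stand-ins for values that really come from ModelUtils and NetworkSettings.

from mlagents.torch_utils import torch
from mlagents.trainers.torch.layers import LinearEncoder, SimpleTransformer

encoder_input_size = 6   # assumed; really returned by ModelUtils.create_input_processors
emb_size = 64
num_layers = 2           # assumed network_settings.num_layers
h_size = 128             # assumed hidden size

transformer = SimpleTransformer(
    x_self_size=encoder_input_size, entities_sizes=[4], embedding_size=emb_size
)
linear_encoder = LinearEncoder(emb_size + encoder_input_size, num_layers, h_size)

processed_vec = torch.ones(2, encoder_input_size)  # encoded vector observations
var_len_input = torch.ones(2, 20, 4)               # stands in for vis_inputs[0].reshape(-1, 20, 4)
x_self = processed_vec.reshape(-1, processed_vec.shape[1])
output = transformer(x_self, [var_len_input])                  # (2, emb_size)
encoding = linear_encoder(torch.cat([output, x_self], dim=1))  # goes on to the LSTM / heads as before
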

ml-agents/mlagents/trainers/torch/utils.py (14 changes)


        # Total output size for all inputs + CNNs
        total_processed_size = vector_size + visual_output_size
        # HardCoded
        max_observables, observable_size, output_size = (20, 4, 64)
        attention = MultiHeadAttention(
            query_size=output_size,
            key_size=output_size,
            value_size=output_size,
            output_size=output_size,
            num_heads=4,
            embedding_size=64,
        )
            attention,
            output_size,  # total_processed_size + output_size,
            total_processed_size,
        )
    @staticmethod
