
Fixing some bugs

/exp-alternate-atten
vincentpierre 4 years ago
Commit
0b6c2ed3
2 files changed, 104 insertions and 67 deletions
  1. ml-agents/mlagents/trainers/torch/layers.py (58 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (113 changes)

ml-agents/mlagents/trainers/torch/layers.py (58 changes)


from mlagents.torch_utils import torch
import abc
from typing import Tuple, List
from typing import Tuple, List, Optional
from enum import Enum

key: torch.Tensor,
value: torch.Tensor,
key_mask: torch.Tensor,
number_of_keys: int = -1,
number_of_queries: int = -1
b, n_q, n_k = query.size(0), query.size(1), key.size(1)
b = -1 # the batch size
# This is to avoid using .size() when possible as Barracuda does not support
n_q = number_of_queries if number_of_queries != -1 else query.size(1)
n_k = number_of_keys if number_of_keys != -1 else key.size(1)
# Create a key mask : Only 1 if all values are 0 # shape = (b, n_k)
# key_mask = torch.sum(key ** 2, axis=2) < 0.01

qk = torch.matmul(query, key) # (b, h, n_q, n_k)
qk = qk / (self.embedding_size ** 0.5) + key_mask * self.NEG_INF
qk = (1 - key_mask) * qk / (self.embedding_size ** 0.5) + key_mask * self.NEG_INF
att = torch.softmax(qk, dim=3) # (b, h, n_q, n_k)

out = self.fc_out(value_attention) # (b, n_q, emb)
return out, att
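For reference, the masking change above (scaling only the unmasked logits and pushing masked entries toward NEG_INF) can be sketched in isolation. Shapes and the constant below are illustrative assumptions, not the module's exact code:

import torch

def masked_attention_weights(query, key, key_mask, embedding_size, neg_inf=-1e6):
    # query: (b, h, n_q, e), key: (b, h, e, n_k), key_mask: (b, 1, 1, n_k) with 1 = padded key
    qk = torch.matmul(query, key)  # raw logits, (b, h, n_q, n_k)
    qk = (1 - key_mask) * qk / (embedding_size ** 0.5) + key_mask * neg_inf
    return torch.softmax(qk, dim=3)  # padded keys receive ~0 attention weight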
class ZeroObservationMask(torch.nn.Module):
    """
    Takes a List of Tensors and returns a List of mask Tensors with 1 if the input was
    all zeros and 0 otherwise. This is used in the Attention layer to mask the padding
    observations.
    """

    def __init__(self):
        super().__init__()

    def forward(self, observations: List[torch.Tensor]):
        with torch.no_grad():
            # Generate the masking tensors for each entity tensor (mask only if all zeros)
            key_masks: List[torch.Tensor] = [
                (torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
                for ent in observations
            ]
        return key_masks
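A minimal usage sketch of the new module, assuming ZeroObservationMask as defined above and a toy batch where only the first of three entities is non-zero:

import torch

mask_module = ZeroObservationMask()
obs = torch.zeros(2, 3, 4)    # (batch, entities, obs size), zero-padded
obs[:, 0, :] = 1.0            # only the first entity carries data
masks = mask_module([obs])    # [tensor of shape (2, 3)]: 0.0 for real entities, 1.0 for padding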
class SimpleTransformer(torch.nn.Module):

super().__init__()
self.self_size = x_self_size
self.entities_sizes = entities_sizes
self.entities_num_max_elements: Optional[List[int]] = None
self.ent_encoders = torch.nn.ModuleList(
[
LinearEncoder(self.self_size + ent_size, 1, embedding_size)

)
self.residual_layer = LinearEncoder(embedding_size, 1, embedding_size)
def forward(self, x_self: torch.Tensor, entities: List[torch.Tensor]):
# Generate the masking tensors for each entities tensor (mask only if all zeros)
key_masks: List[torch.Tensor] = [
(torch.sum(ent ** 2, axis=2) < 0.01).type(torch.FloatTensor)
for ent in entities
]
def forward(self, x_self: torch.Tensor, entities: List[torch.Tensor], key_masks: List[torch.Tensor]):
# Gather the maximum number of entities information
if self.entities_num_max_elements is None:
self.entities_num_max_elements = []
for ent in entities:
self.entities_num_max_elements.append(ent.shape[1])
for ent_size, ent in zip(self.entities_sizes, entities):
num_entities = ent.shape[1]
expanded_self = x_self.reshape(-1, 1, self.self_size).repeat(
1, num_entities, 1
)
for num_entities, ent in zip(self.entities_num_max_elements, entities):
expanded_self = x_self.reshape(-1, 1, self.self_size)
# .repeat(
# 1, num_entities, 1
# )
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
# Generate the tensor that will serve as query, key and value to self attention
qkv = torch.cat(

mask = torch.cat(key_masks, dim=1)
# Feed to self attention
output, _ = self.attention(qkv, qkv, qkv, mask)
max_num_ent = sum(self.entities_num_max_elements)
output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent, max_num_ent)
max_num_ent = qkv.shape[1]
numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1), dim=1)
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPISLON
output = numerator / denominator
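The last lines above average the attention output over real (unmasked) entities only. A self-contained sketch of that masked mean pooling, with the epsilon value assumed:

import torch

def masked_average_pool(output, mask, epsilon=1e-7):
    # output: (b, n, emb) per-entity attention output, mask: (b, n) with 1 = padded entity
    keep = (1 - mask).unsqueeze(2)                              # (b, n, 1), 1 for real entities
    numerator = torch.sum(output * keep, dim=1)                 # sum embeddings of real entities
    denominator = torch.sum(1 - mask, dim=1, keepdim=True) + epsilon
    return numerator / denominator                              # (b, emb) mean over real entities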

ml-agents/mlagents/trainers/torch/networks.py (113 changes)


from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM, LinearEncoder, SimpleTransformer
from mlagents.trainers.torch.layers import LSTM, LinearEncoder, SimpleTransformer, ZeroObservationMask
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

normalize=self.normalize,
)
emb_size = 64
self.transformer = SimpleTransformer(
x_self_size=64,
entities_sizes=[64], # hard coded, 4 obs per entity
embedding_size=emb_size,
)
self.use_fc = False
if not self.use_fc:
emb_size = 32
self.masking_module = ZeroObservationMask()
self.transformer = SimpleTransformer(
x_self_size=32,
entities_sizes=[32], # hard coded, 4 obs per entity
embedding_size=emb_size,
)
# total_enc_size = encoder_input_size + encoded_act_size
self.self_embedding = LinearEncoder(6, 2, 64)
self.obs_embeding = LinearEncoder(4, 2, 64)
# self.self_and_obs_embedding = LinearEncoder(64 + 64, 1, 64)
# self.dense_after_attention = LinearEncoder(64, 1, 64)
self.self_embedding = LinearEncoder(6, 2, 32)
self.obs_embeding = LinearEncoder(4, 2, 32)
self.linear_encoder = LinearEncoder(
64 + 64, network_settings.num_layers, self.h_size
)
self.linear_encoder = LinearEncoder(
emb_size + 32, network_settings.num_layers - 1, self.h_size
)
else:
self.linear_encoder = LinearEncoder(
6 + 4 * 20, network_settings.num_layers + 2, self.h_size
)
if self.use_lstm:
self.lstm = LSTM(self.h_size, self.m_size)

else:
inputs = torch.cat(encodes, dim=-1)
x_self = processed_vec.reshape(-1, processed_vec.shape[1])
x_self = self.self_embedding(x_self)
var_len_input = vis_inputs[0].reshape(-1, 20, 4)
var_len_input = self.obs_embeding(var_len_input)
output = self.transformer(x_self, [var_len_input])
if not self.use_fc:
x_self = self.self_embedding(processed_vec)
var_len_input = vis_inputs[0].reshape(-1, 20, 4)
processed_var_len_input = self.obs_embeding(var_len_input)
output = self.transformer(x_self, [processed_var_len_input], self.masking_module([var_len_input]))
# # TODO : This is a Hack
# var_len_input = vis_inputs[0].reshape(-1, 20, 4)
# key_mask = (
# torch.sum(var_len_input ** 2, axis=2) < 0.01
# ).type(torch.FloatTensor) # 1 means mask and 0 means let though
# x_self = processed_vec.reshape(-1, processed_vec.shape[1])
# x_self = self.self_embedding(x_self) # (b, 1,64)
# expanded_x_self = x_self.reshape(-1, 1, 64).repeat(1, 20, 1)
# obj_emb = self.obs_embeding(var_len_input)
# objects = torch.cat([expanded_x_self, obj_emb], dim=2) # (b,20,64)
# obj_and_self = self.self_and_obs_embedding(objects) # (b,20,64)
# # add the self to the entities
# # self_and_key_emb = torch.cat(
# # [x_self.reshape(-1, 1, 64), obj_and_self], dim=1
# # ) # (b,21,64)
# # key_mask = torch.cat(
# # [torch.zeros((self_and_key_emb.shape[0], 1)), key_mask], dim=1
# # ) # first one is never masked
# # output, _ = self.attention(
# # self_and_key_emb, self_and_key_emb, self_and_key_emb, key_mask
# # ) # (b, 21, 64)
# output, _ = self.attention(
# obj_and_self, obj_and_self, obj_and_self, key_mask
# ) # (b, 21, 64)
# output = self.dense_after_attention(output) + obj_and_self
# output = torch.sum(
# output * (1 - key_mask).reshape(-1, 20, 1), dim=1
# ) / (torch.sum(
# 1 - key_mask, dim=1, keepdim=True
# ) + 0.001 ) # average pooling
encoding = self.linear_encoder(torch.cat([output, x_self], dim=1))
else:
encoding = self.linear_encoder(torch.cat([vis_inputs[0].reshape(-1, 80), processed_vec], dim=1))
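Taken together, the non-fc branch of the forward pass now follows this flow. The sketch below uses the hard-coded sizes above (a 6-d self observation and 20 entities of 4 values each) and illustrative parameter names, not the trainer's exact code:

import torch

def encode_with_attention(processed_vec, var_len_obs, self_embedding, obs_embedding,
                          masking_module, transformer, linear_encoder):
    x_self = self_embedding(processed_vec)              # (b, 6)  -> (b, 32)
    entities_raw = var_len_obs.reshape(-1, 20, 4)       # zero-padded entity observations
    masks = masking_module([entities_raw])              # masks computed on raw (unembedded) obs
    entities = obs_embedding(entities_raw)               # (b, 20, 4) -> (b, 20, 32)
    attended = transformer(x_self, [entities], masks)   # attention + masked mean pooling
    return linear_encoder(torch.cat([attended, x_self], dim=1))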
if self.use_lstm:
# Resize to (batch, sequence length, encoding size)
