
custom layer norm

/layernorm
Andrew Cohen, 4 years ago
Current commit
96c01a63
4 files changed, 164 insertions and 169 deletions
  1. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (25 changed lines)
  2. ml-agents/mlagents/trainers/torch/attention.py (273 changed lines)
  3. ml-agents/mlagents/trainers/torch/layers.py (14 changed lines)
  4. ml-agents/mlagents/trainers/torch/networks.py (21 changed lines)

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (25 changed lines)


)
self.optimizer = torch.optim.Adam(
params, lr=self.trainer_settings.hyperparameters.learning_rate, weight_decay=1E-6
params,
lr=self.trainer_settings.hyperparameters.learning_rate,
weight_decay=1e-6,
)
self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",

"Policy/Beta": decay_bet,
}
for name, params in list(self.policy.actor_critic.network_body.transformer.named_parameters()):
update_stats["Policy/" + name + '_mean'] = torch.mean(params).item()
update_stats["Policy/" + name + '_std'] = torch.std(params).item()
update_stats["Policy/" + name + '_grad_mag'] = torch.norm(params.grad).item()
update_stats["Policy/" + name + '_grad_mean'] = torch.mean(params.grad).item()
update_stats["Policy/" + name + '_grad_std'] = torch.std(params.grad).item()
for name, params in list(self.policy.actor_critic.network_body.linear_encoder.named_parameters()):
update_stats["Policy/" + name + '_grad_mag'] = torch.norm(params.grad).item()
update_stats["Policy/" + name + '_grad_mean'] = torch.mean(params.grad).item()
update_stats["Policy/" + name + '_grad_std'] = torch.std(params.grad).item()
# for name, params in list(self.policy.actor_critic.network_body.transformer.named_parameters()):
# update_stats["Policy/" + name + '_mean'] = torch.mean(params).item()
# update_stats["Policy/" + name + '_std'] = torch.std(params).item()
# update_stats["Policy/" + name + '_grad_mag'] = torch.norm(params.grad).item()
# update_stats["Policy/" + name + '_grad_mean'] = torch.mean(params.grad).item()
# update_stats["Policy/" + name + '_grad_std'] = torch.std(params.grad).item()
# for name, params in list(self.policy.actor_critic.network_body.linear_encoder.named_parameters()):
# update_stats["Policy/" + name + '_grad_mag'] = torch.norm(params.grad).item()
# update_stats["Policy/" + name + '_grad_mean'] = torch.mean(params.grad).item()
# update_stats["Policy/" + name + '_grad_std'] = torch.std(params.grad).item()
for reward_provider in self.reward_signals.values():
update_stats.update(reward_provider.update(batch))
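The hunk above logs per-parameter means, standard deviations, and gradient norms into update_stats for debugging. A minimal sketch of that pattern, assuming backward() has already populated .grad; the helper name grad_stats is illustrative and not part of the trainer:

    import torch

    def grad_stats(module: torch.nn.Module, prefix: str) -> dict:
        # Mean/std of each parameter plus the magnitude, mean, and std of its gradient.
        stats = {}
        for name, param in module.named_parameters():
            stats[prefix + "/" + name + "_mean"] = torch.mean(param).item()
            stats[prefix + "/" + name + "_std"] = torch.std(param).item()
            if param.grad is not None:
                stats[prefix + "/" + name + "_grad_mag"] = torch.norm(param.grad).item()
                stats[prefix + "/" + name + "_grad_mean"] = torch.mean(param.grad).item()
                stats[prefix + "/" + name + "_grad_std"] = torch.std(param.grad).item()
        return stats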

ml-agents/mlagents/trainers/torch/attention.py (273 changed lines)


from typing import Tuple, Optional, List
from mlagents.trainers.torch.layers import (
LinearEncoder,
linear_layer,
def grad_hook(mod, inp, out):
print("")
print(mod)
print("-" * 10 + " Incoming Gradients " + "-" * 10)
print("")
print("Incoming Grad value: {}".format(inp[0].data))
print("")
print("Upstream Grad value: {}".format(out[0].data))
- query: of dimensions (batch_size, number_of_queries, key_size)
- key: of dimensions (batch_size, number_of_keys, key_size)
- value: of dimensions (batch_size, number_of_keys, value_size)
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
- The output: (batch_size, number_of_queries, output_size)
- The output: (batch_size, number_of_queries, embedding_size)
def __init__(self, num_heads: int, embedding_size: int):
def __init__(self, embedding_size: int, num_heads: int):
self.n_heads, self.embedding_size = num_heads, embedding_size
# self.fc_q = torch.nn.linear(query_size, self.n_heads * self.embedding_size)
# self.fc_k = torch.nn.linear(key_size, self.n_heads * self.embedding_size)
# self.fc_v = torch.nn.linear(value_size, self.n_heads * self.embedding_size)
# self.fc_q = torch.nn.Linear(self.embedding_size // self.n_heads, self.embedding_size // self.n_heads)
# self.fc_k = torch.nn.Linear(self.embedding_size // self.n_heads, self.embedding_size // self.n_heads)
# self.fc_v = torch.nn.Linear(self.embedding_size // self.n_heads, self.embedding_size // self.n_heads)
# self.fc_qk = torch.nn.Linear(self.embedding_size, 2 * self.embedding_size)
self.fc_q = torch.nn.Linear(self.embedding_size, self.embedding_size)
self.fc_k = torch.nn.Linear(self.embedding_size, self.embedding_size)
self.fc_v = torch.nn.Linear(self.embedding_size, self.embedding_size)
self.fc_out = torch.nn.Linear(self.embedding_size, self.embedding_size)
for layer in [self.fc_q, self.fc_k, self.fc_v, self.fc_out]:
# for layer in [self.fc_qk, self.fc_v, self.fc_out]:
torch.torch.nn.init.normal_(
layer.weight, std=(0.125 / embedding_size) ** 0.5
)
self.embedding_norm = torch.nn.LayerNorm(embedding_size)
self.n_heads = num_heads
self.head_size: int = embedding_size // self.n_heads
self.embedding_size: int = self.head_size * self.n_heads
inp: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
n_q: int,
n_k: int,
number_of_keys: int = -1,
number_of_queries: int = -1,
# This is to avoid using .size() when possible as Barracuda does not support
n_q = number_of_queries if number_of_queries != -1 else query.size(1)
n_k = number_of_keys if number_of_keys != -1 else key.size(1)
inp = self.embedding_norm(inp)
query = self.fc_q(inp)
query = query.reshape(b, n_q, self.n_heads, self.embedding_size // self.n_heads)
key = self.fc_k(inp)
key = key.reshape(b, n_q, self.n_heads, self.embedding_size // self.n_heads)
query = query.reshape(
b, n_q, self.n_heads, self.head_size
) # (b, n_q, h, emb / h)
key = key.reshape(b, n_k, self.n_heads, self.head_size) # (b, n_k, h, emb / h)
value = value.reshape(
b, n_k, self.n_heads, self.head_size
) # (b, n_k, h, emb / h)
# query = self.fc_q(inp)
# qk = self.fc_qk(inp)
# qk = qk.reshape(b, n_q, self.n_heads, self.embedding_size // self.n_heads, 2)
# query, key = torch.split(qk, 1, dim=-1)
# query = torch.squeeze(query, dim=-1)
# key = torch.squeeze(key, dim=-1)
value = self.fc_v(inp) # (b, n_k, h*d)
value = value.reshape(b, n_k, self.n_heads, self.embedding_size // self.n_heads)
# query = query.reshape(b, n_q, self.n_heads, self.embedding_size // self.n_heads)
# key = key.reshape(b, n_k, self.n_heads, self.embedding_size // self.n_heads)
# value = value.reshape(b, n_k, self.n_heads, self.embedding_size // self.n_heads)
# query = self.fc_q(query) # (b, n_q, h*d)
# key = self.fc_k(key) # (b, n_k, h*d)
# value = self.fc_v(value) # (b, n_k, h*d)
query = query.permute([0, 2, 1, 3]) # (b, h, n_q, emb)
query = query.permute([0, 2, 1, 3]) # (b, h, n_q, emb / h)
key = key.permute([0, 2, 1, 3]) # (b, h, emb, n_k)
key = key.permute([0, 2, 1, 3]) # (b, h, emb / h, n_k)
key = key.permute([0, 1, 3, 2]) # (b, h, emb, n_k)
key = key.permute([0, 1, 3, 2]) # (b, h, emb / h, n_k)
qk = torch.matmul(query, key) # (b, h, n_q, n_k)

att = torch.softmax(qk, dim=3) # (b, h, n_q, n_k)
value = value.permute([0, 2, 1, 3]) # (b, h, n_k, emb)
value_attention = torch.matmul(att, value) # (b, h, n_q, emb)
value = value.permute([0, 2, 1, 3]) # (b, h, n_k, emb / h)
value_attention = torch.matmul(att, value) # (b, h, n_q, emb / h)
value_attention = value_attention.permute([0, 2, 1, 3]) # (b, n_q, h, emb)
value_attention = value_attention.permute([0, 2, 1, 3]) # (b, n_q, h, emb / h)
) # (b, n_q, h*emb)
out = self.fc_out(value_attention) # (b, n_q, emb)
# if out.requires_grad:
# out.register_hook(lambda x: print(x))
# out = self.out_norm(out)
return out, att
) # (b, n_q, emb)
return value_attention, att
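For reference, a shape-check sketch of the attention arithmetic in the forward above, with made-up sizes; only the reshape/permute/matmul/softmax sequence mirrors the class:

    import torch

    b, n_q, n_k, h, emb = 2, 3, 5, 4, 64
    head_size = emb // h
    query = torch.randn(b, n_q, h, head_size).permute([0, 2, 1, 3])  # (b, h, n_q, emb / h)
    key = torch.randn(b, n_k, h, head_size).permute([0, 2, 3, 1])    # (b, h, emb / h, n_k)
    value = torch.randn(b, n_k, h, head_size).permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
    att = torch.softmax(torch.matmul(query, key), dim=3)             # (b, h, n_q, n_k)
    out = torch.matmul(att, value)                                   # (b, h, n_q, emb / h)
    out = out.permute([0, 2, 1, 3]).reshape(b, n_q, emb)             # (b, n_q, emb)
    print(out.shape)  # torch.Size([2, 3, 64])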
class SimpleTransformer(torch.nn.Module):
class EntityEmbeddings(torch.nn.Module):
A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses
multi head self attention to encode information about a "Self" and a list of
relevant "Entities".
EPISLON = 1e-7
self, x_self_size: int, entities_sizes: List[int], embedding_size: int
self,
x_self_size: int,
entity_sizes: List[int],
entity_num_max_elements: List[int],
embedding_size: int,
concat_self: bool = True,
self.self_size = x_self_size
self.entities_sizes = entities_sizes
self.entities_num_max_elements: Optional[List[int]] = None
self.self_size: int = x_self_size
self.entity_sizes: List[int] = entity_sizes
self.entity_num_max_elements: List[int] = entity_num_max_elements
self.concat_self: bool = concat_self
self.embedding_norm = LayerNorm()
# If not concatenating self, input to encoder is just entity size
if not concat_self:
self.self_size = 0
# LinearEncoder(self.self_size + ent_size, 2, embedding_size)
# from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
# linear_layer(self.self_size + ent_size, embedding_size, Initialization.Normal, kernel_gain=(.125 / (self.self_size + ent_size)) ** 0.5)
# linear_layer(ent_size, embedding_size, Initialization.Normal, kernel_gain=(.125 / (self.self_size + ent_size)) ** 0.5)
ent_size,
self.self_size + ent_size,
# LinearEncoder(self.self_size + ent_size, 1, embedding_size, layer_norm=False)
for ent_size in self.entities_sizes
for ent_size in self.entity_sizes
self.attention = MultiHeadAttention(num_heads=4, embedding_size=embedding_size)
# self.residual_layer = torch.nn.Linear(
# embedding_size, embedding_size
# )
# torch.torch.nn.init.normal_(self.residual_layer.weight, std = 0.125 * embedding_size ** -0.5)
self.res_norm = torch.nn.LayerNorm(embedding_size)
self,
x_self: torch.Tensor,
entities: List[torch.Tensor],
key_masks: List[torch.Tensor],
) -> torch.Tensor:
# Gather the maximum number of entities information
if self.entities_num_max_elements is None:
self.entities_num_max_elements = []
for ent in entities:
self.entities_num_max_elements.append(ent.shape[1])
# Concatenate all observations with self
# self_and_ent: List[torch.Tensor] = []
# for num_entities, ent in zip(self.entities_num_max_elements, entities):
# expanded_self = x_self.reshape(-1, 1, self.self_size)
# # .repeat(
# # 1, num_entities, 1
# # )
# expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
# self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
# Generate the tensor that will serve as query, key and value to self attention
qkv = torch.cat(
[ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, entities)],
self, x_self: torch.Tensor, entities: List[torch.Tensor]
) -> Tuple[torch.Tensor, int]:
if self.concat_self:
# Concatenate all observations with self
self_and_ent: List[torch.Tensor] = []
for num_entities, ent in zip(self.entity_num_max_elements, entities):
expanded_self = x_self.reshape(-1, 1, self.self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
else:
self_and_ent = entities
# Encode and concatenate entities
encoded_entities = torch.cat(
[ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
mask = torch.cat(key_masks, dim=1)
# Feed to self attention
max_num_ent = sum(self.entities_num_max_elements)
output, _ = self.attention(qkv, mask, max_num_ent, max_num_ent)
# residual 1
output += qkv
output = self.res_norm(output)
# residual 2
# output = self.residual_layer(output) + output #qkv
# average pooling
numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1), dim=1)
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPISLON
output = numerator / denominator
# Residual between x_self and the output of the module
output = torch.cat([output, x_self], dim=1)
return output
encoded_entities = self.embedding_norm(encoded_entities)
return encoded_entities
@staticmethod
def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:

for ent in observations
]
return key_masks
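The masked average pooling used by both forward passes weights each entity by (1 - mask) and divides by the count of unmasked entities, so padded entries do not dilute the average. A minimal sketch with illustrative tensors:

    import torch

    EPSILON = 1e-7
    b, max_num_ent, emb = 2, 5, 8
    output = torch.randn(b, max_num_ent, emb)  # per-entity attention output
    mask = torch.zeros(b, max_num_ent)         # 1.0 marks a padded / absent entity
    mask[:, 3:] = 1.0                          # pretend the last two entities are padding
    numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1), dim=1)
    denominator = torch.sum(1 - mask, dim=1, keepdim=True) + EPSILON
    pooled = numerator / denominator           # (b, emb), averaged over real entities only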
class ResidualSelfAttention(torch.nn.Module):
"""
A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses
multi head self attention to encode information about a "Self" and a list of
relevant "Entities".
"""
EPSILON = 1e-7
def __init__(
self,
embedding_size: int,
entity_num_max_elements: List[int],
num_heads: int = 4,
):
super().__init__()
self.entity_num_max_elements: List[int] = entity_num_max_elements
self.max_num_ent = sum(entity_num_max_elements)
self.attention = MultiHeadAttention(
num_heads=num_heads, embedding_size=embedding_size
)
self.residual_norm = LayerNorm() # torch.nn.LayerNorm(embedding_size)
self.fc_q = linear_layer(
embedding_size,
embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
self.fc_k = linear_layer(
embedding_size,
embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
self.fc_v = linear_layer(
embedding_size,
embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
self.fc_out = linear_layer(
embedding_size,
embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
# Gather the maximum number of entities information
mask = torch.cat(key_masks, dim=1)
# Feed to self attention
query = self.fc_q(inp) # (b, n_q, emb)
key = self.fc_k(inp) # (b, n_k, emb)
value = self.fc_v(inp) # (b, n_k, emb)
output, _ = self.attention(
query, key, value, self.max_num_ent, self.max_num_ent, mask
)
# Residual
output = self.fc_out(output) + inp
output = self.residual_norm(output)
# Average Pooling
numerator = torch.sum(
output * (1 - mask).reshape(-1, self.max_num_ent, 1), dim=1
)
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
output = numerator / denominator
# Residual between x_self and the output of the module
return output
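A minimal sketch of the residual-then-normalize step above, using torch.nn.LayerNorm as a stand-in for the custom LayerNorm and random tensors in place of the attention output:

    import torch

    emb = 64
    fc_out = torch.nn.Linear(emb, emb)
    residual_norm = torch.nn.LayerNorm(emb)         # stand-in for the parameter-free LayerNorm
    inp = torch.randn(2, 5, emb)                    # pre-attention entity embeddings
    attended = torch.randn(2, 5, emb)               # placeholder for the MultiHeadAttention output
    output = residual_norm(fc_out(attended) + inp)  # residual connection, then normalization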

ml-agents/mlagents/trainers/torch/layers.py (14 changed lines)


class LayerNorm(torch.nn.Module):
def __init__(self, input_size: int, elementwise_affine: bool = False):
super().__init__()
self.gamma = torch.nn.Parameter(
torch.ones(input_size, requires_grad=elementwise_affine)
)
self.beta = torch.nn.Parameter(
torch.zeros(input_size, requires_grad=elementwise_affine)
)
centered_activations = layer_activations - mean
var = torch.mean(centered_activations ** 2, dim=-1, keepdim=True)
return centered_activations / (torch.sqrt(var + 1e-5)) * self.gamma + self.beta
var = torch.mean((layer_activations - mean) ** 2, dim=-1, keepdim=True)
return (layer_activations - mean) / (torch.sqrt(var + 1e-5))
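A self-contained sketch of the parameter-free normalization that the new forward above implements; the class name SimpleLayerNorm is illustrative (the trainer's class is LayerNorm):

    import torch

    class SimpleLayerNorm(torch.nn.Module):
        # Subtract the mean and divide by the standard deviation over the last
        # dimension, with no learnable gamma/beta, matching the forward above.
        def forward(self, layer_activations: torch.Tensor) -> torch.Tensor:
            mean = torch.mean(layer_activations, dim=-1, keepdim=True)
            var = torch.mean((layer_activations - mean) ** 2, dim=-1, keepdim=True)
            return (layer_activations - mean) / torch.sqrt(var + 1e-5)

    x = torch.randn(4, 16)
    y = SimpleLayerNorm()(x)
    print(y.mean(dim=-1))                 # ~0 per row
    print(y.std(dim=-1, unbiased=False))  # ~1 per row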
class MemoryModule(torch.nn.Module):

ml-agents/mlagents/trainers/torch/networks.py (21 changed lines)


from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.attention import SimpleTransformer
from mlagents.trainers.torch.attention import ResidualSelfAttention, EntityEmbeddings
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

# self.h_size,
# self.h_size
# )
self.transformer = SimpleTransformer(
x_self_len, entities_sizes, self.n_embd
self.entity_embedding = EntityEmbeddings(
x_self_len, entities_sizes, [20], self.n_embd, concat_self=False
# self.embedding_norm = torch.nn.LayerNorm(self.n_embd)
self.transformer = ResidualSelfAttention(self.n_embd, [20])
# self.transformer = SmallestAttention(x_self_len, entities_sizes, self.h_size, self.h_size)
# self.transformer = SmallestAttention(64, [64], self.h_size, self.h_size)
# self.use_fc = True

raise Exception("No valid inputs to network.")
for _, tens in list(self.transformer.named_parameters()):
tens.retain_grad()
for _, tens in list(self.entity_embedding.named_parameters()):
tens.retain_grad()
# for _, tens in list(self.embedding_norm.named_parameters()):
# tens.retain_grad()
total_enc_size += encoded_act_size
self.linear_encoder = LinearEncoder(total_enc_size, n_layers, self.h_size)
for _, tens in list(self.linear_encoder.named_parameters()):

x_self_encoded = x_self
# x_self_encoded = self.x_self_enc(x_self)
embedded_entities = self.entity_embedding(x_self_encoded, var_len_inputs)
# embedded_entities = self.embedding_norm(embedded_entities)
x_self_encoded,
var_len_inputs,
SimpleTransformer.get_masks(var_len_inputs),
embedded_entities, EntityEmbeddings.get_masks(var_len_inputs)
encoded_state = torch.cat([x_self_encoded, encoded_state], dim=1)
# print("\n\n\nUsing transformer ", self.transformer, "use fc = ", self.use_fc, " x_self.shape=",x_self_encoded.shape," var_len_inputs[0].shape=",var_len_inputs[0].shape," len(var_len_inputs)=",len(var_len_inputs))
else:
encoded_state = torch.cat(encodes, dim=1)
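Taken together, the networks.py changes split the old SimpleTransformer call into an EntityEmbeddings step followed by ResidualSelfAttention before concatenating with x_self. A loose sketch of that data flow, with torch.nn stand-ins for the two modules and made-up sizes (the real constructors are the ones shown in the diff above):

    import torch

    b, x_self_len, ent_size, max_ents, n_embd = 2, 10, 6, 20, 64
    x_self = torch.randn(b, x_self_len)
    entities = torch.randn(b, max_ents, ent_size)         # one variable-length observation
    entity_embedding = torch.nn.Linear(ent_size, n_embd)  # stand-in for EntityEmbeddings
    embedded = entity_embedding(entities)                 # (b, max_ents, n_embd)
    attended = embedded.mean(dim=1)                       # stand-in for ResidualSelfAttention + pooling
    encoded_state = torch.cat([x_self, attended], dim=1)  # matches the final concatenation above
    print(encoded_state.shape)  # torch.Size([2, 74])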
