
Merge branch 'develop-attention-refactor' into develop-centralizedcritic-mm

/develop/centralizedcritic
Ervin Teng, 4 years ago
Commit 30a09c6f
2 changed files, with 28 additions and 61 deletions:
  1. ml-agents/mlagents/trainers/tests/torch/test_attention.py (85 changes)
  2. ml-agents/mlagents/trainers/torch/attention.py (4 changes)
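
Note: the hunks below capture two API changes from the attention refactor. MultiHeadAttention is now constructed from just an embedding size and a head count, with the forward pass taking the query/key counts explicitly, and the old SimpleTransformer is replaced by the EntityEmbeddings / ResidualSelfAttention pair. A minimal before/after sketch of the constructor and call pattern, mirroring the updated test (it assumes the ml-agents source at this commit is importable; the sizes are illustrative):

import torch
from mlagents.trainers.torch.attention import MultiHeadAttention

# Old API (removed):  MultiHeadAttention(q_size, k_size, v_size, o_size, n_h, emb_size)
# New API (added):    MultiHeadAttention(emb_size, n_h)
n_h, emb_size = 4, 12
b, n_q, n_k = 5, 2, 3  # illustrative batch / query / key counts

mha = MultiHeadAttention(emb_size, n_h)
query = torch.ones((b, n_q, emb_size))
key = torch.ones((b, n_k, emb_size))
value = torch.ones((b, n_k, emb_size))

# As in the updated test, the forward pass now takes the number of queries and keys explicitly.
output, attention = mha.forward(query, key, value, n_q, n_k)
assert output.shape == (b, n_q, emb_size)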

ml-agents/mlagents/trainers/tests/torch/test_attention.py (85 changes)


 import numpy as np
 from mlagents.trainers.torch.layers import linear_layer
-from mlagents.trainers.torch.attention import MultiHeadAttention, SimpleTransformer
+from mlagents.trainers.torch.attention import (
+    MultiHeadAttention,
+    EntityEmbeddings,
+    ResidualSelfAttention,
+)
-    q_size, k_size, v_size, o_size, n_h, emb_size = 7, 8, 9, 10, 11, 12
+    n_h, emb_size = 4, 12
-    mha = MultiHeadAttention(q_size, k_size, v_size, o_size, n_h, emb_size)
+    mha = MultiHeadAttention(emb_size, n_h)
-    query = torch.ones((b, n_q, q_size))
-    key = torch.ones((b, n_k, k_size))
-    value = torch.ones((b, n_k, v_size))
+    query = torch.ones((b, n_q, emb_size))
+    key = torch.ones((b, n_k, emb_size))
+    value = torch.ones((b, n_k, emb_size))
-    output, attention = mha.forward(query, key, value)
+    output, attention = mha.forward(query, key, value, n_q, n_k)
-    assert output.shape == (b, n_q, o_size)
+    assert output.shape == (b, n_q, emb_size)
-    q_size, k_size, v_size, o_size, n_h, emb_size = 7, 8, 9, 10, 11, 12
+    n_h, emb_size = 4, 12
-    mha = MultiHeadAttention(q_size, k_size, v_size, o_size, n_h, emb_size)
+    mha = MultiHeadAttention(emb_size, n_h)
+    # create a key input with some keys all 0
+    query = torch.ones((b, n_q, emb_size))
+    key = torch.ones((b, n_k, emb_size))
+    value = torch.ones((b, n_k, emb_size))
-    # create a key input with some keys all 0
-    key = torch.ones((b, n_k, k_size))
     mask = torch.zeros((b, n_k))
     for i in range(n_k):
         if i % 3 == 0:

-    query = torch.ones((b, n_q, q_size))
-    value = torch.ones((b, n_k, v_size))
+    _, attention = mha.forward(query, key, value, n_q, n_k, mask)
-    _, attention = mha.forward(query, key, value, mask)
     for i in range(n_k):
         if i % 3 == 0:
             assert torch.sum(attention[:, :, :, i] ** 2) < epsilon
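
What the masking test above checks, in isolation: keys flagged by the mask should receive (numerically) zero attention weight. A self-contained PyTorch sketch of that mechanism, independent of the ml-agents classes and with illustrative sizes, where a large negative number is added to the masked scores before the softmax (the exact constant and mask convention inside MultiHeadAttention may differ):

import torch

b, n_h, n_q, n_k = 5, 4, 2, 6
scores = torch.randn(b, n_h, n_q, n_k)  # raw query-key scores per head
mask = torch.zeros(b, n_k)
mask[:, ::3] = 1  # 1 marks keys that should be ignored

# Broadcast the key mask over heads and queries, push masked scores far down.
masked_scores = scores - 1e6 * mask.reshape(b, 1, 1, n_k)
attention = torch.softmax(masked_scores, dim=-1)

# Masked keys end up with essentially zero attention weight.
assert torch.sum(attention[:, :, :, 0] ** 2) < 1e-4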

-def test_multi_head_attention_training():
-    np.random.seed(1336)
-    torch.manual_seed(1336)
-    size, n_h, n_k, n_q = 3, 10, 5, 1
-    embedding_size = 64
-    mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
-    optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
-    batch_size = 200
-    point_range = 3
-    init_error = -1.0
-    for _ in range(50):
-        query = torch.rand((batch_size, n_q, size)) * point_range * 2 - point_range
-        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
-        value = key
-        with torch.no_grad():
-            # create the target : The key closest to the query in euclidean distance
-            distance = torch.sum((query - key) ** 2, dim=2)
-            argmin = torch.argmin(distance, dim=1)
-            target = []
-            for i in range(batch_size):
-                target += [key[i, argmin[i], :]]
-            target = torch.stack(target, dim=0)
-            target = target.detach()
-        prediction, _ = mha.forward(query, key, value)
-        prediction = prediction.reshape((batch_size, size))
-        error = torch.mean((prediction - target) ** 2, dim=1)
-        error = torch.mean(error) / 2
-        if init_error == -1.0:
-            init_error = error.item()
-        else:
-            assert error.item() < init_error
-        print(error.item())
-        optimizer.zero_grad()
-        error.backward()
-        optimizer.step()
-    assert error.item() < 0.5
 def test_zero_mask_layer():
     batch_size, size = 10, 30

     input_1 = generate_input_helper(masking_pattern_1)
     input_2 = generate_input_helper(masking_pattern_2)
-    masks = SimpleTransformer.get_masks([input_1, input_2])
+    masks = EntityEmbeddings.get_masks([input_1, input_2])
     assert len(masks) == 2
     masks_1 = masks[0]
     masks_2 = masks[1]

     torch.manual_seed(1336)
     size, n_k, = 3, 5
     embedding_size = 64
-    transformer = SimpleTransformer(size, [size], embedding_size)
+    entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
+    transformer = ResidualSelfAttention(embedding_size, [n_k])
     l_layer = linear_layer(embedding_size, size)
     optimizer = torch.optim.Adam(
         list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001
     )
     init_error = -1.0
-    for _ in range(100):
+    for _ in range(250):
         center = torch.rand((batch_size, size)) * point_range * 2 - point_range
         key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
         with torch.no_grad():

             target = torch.stack(target, dim=0)
             target = target.detach()
-        masks = SimpleTransformer.get_masks([key])
-        prediction = transformer.forward(center, [key], masks)
+        embeddings = entity_embeddings(center, [key])
+        masks = EntityEmbeddings.get_masks([key])
+        prediction = transformer.forward(embeddings, masks)
         prediction = l_layer(prediction)
         prediction = prediction.reshape((batch_size, size))
         error = torch.mean((prediction - target) ** 2, dim=1)
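
The updated training test wires the refactored pieces together by hand where SimpleTransformer used to do it internally: EntityEmbeddings embeds the self observation together with each entity, EntityEmbeddings.get_masks builds key masks from all-zero entity rows, and ResidualSelfAttention attends over the embeddings. A condensed sketch of that wiring, mirroring the added lines above (it assumes the ml-agents source at this commit is importable; sizes are illustrative):

import torch
from mlagents.trainers.torch.attention import EntityEmbeddings, ResidualSelfAttention
from mlagents.trainers.torch.layers import linear_layer

size, n_k, embedding_size, batch_size = 3, 5, 64, 10
entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
transformer = ResidualSelfAttention(embedding_size, [n_k])
l_layer = linear_layer(embedding_size, size)

center = torch.rand((batch_size, size))        # the "self" observation
key = torch.rand((batch_size, n_k, size))      # per-entity observations

embeddings = entity_embeddings(center, [key])  # embed self + entities
masks = EntityEmbeddings.get_masks([key])      # mask entities whose rows are all zero
prediction = l_layer(transformer.forward(embeddings, masks))
prediction = prediction.reshape((batch_size, size))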

ml-agents/mlagents/trainers/torch/attention.py (4 changes)


         self.ent_encoders = torch.nn.ModuleList(
             [
                 LinearEncoder(self.self_size + ent_size, 2, embedding_size)
-                for ent_size in self.entities_sizes
+                for ent_size in self.entity_sizes
             ]
         )

         if self.concat_self:
             # Concatenate all observations with self
             self_and_ent: List[torch.Tensor] = []
-            for num_entities, ent in zip(self.entities_num_max_elements, entities):
+            for num_entities, ent in zip(self.entity_num_max_elements, entities):
                 expanded_self = x_self.reshape(-1, 1, self.self_size)
                 expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
                 self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
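
For reference, the concat_self branch shown above only does a shape manipulation: the fixed-size self observation is tiled once per entity and concatenated onto each entity's features. A self-contained sketch with hypothetical sizes that makes the resulting shapes explicit:

import torch

batch, self_size, num_entities, ent_size = 4, 6, 3, 5
x_self = torch.rand(batch, self_size)
ent = torch.rand(batch, num_entities, ent_size)

expanded_self = x_self.reshape(-1, 1, self_size)                  # (batch, 1, self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)  # (batch, num_entities, self_size)
self_and_ent = torch.cat([expanded_self, ent], dim=2)             # (batch, num_entities, self_size + ent_size)

assert self_and_ent.shape == (batch, num_entities, self_size + ent_size)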
