Merge pull request #4763 from Unity-Technologies/develop-att
WIP Made initial changes to enable dimension properties and added attention module/MLA-1734-demo-provider
GitHub
4 years ago
Current commit
458fee17
56 files changed, 956 insertions(+), 217 deletions(-)
 11  com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
 33  com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
  9  docs/Python-API.md
 16  gym-unity/gym_unity/envs/__init__.py
  4  gym-unity/gym_unity/tests/test_gym.py
 55  ml-agents-envs/mlagents_envs/base_env.py
 19  ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
  6  ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
 24  ml-agents-envs/mlagents_envs/rpc_utils.py
 20  ml-agents-envs/mlagents_envs/tests/test_envs.py
  2  ml-agents-envs/mlagents_envs/tests/test_registry.py
 31  ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
  7  ml-agents-envs/mlagents_envs/tests/test_steps.py
 11  ml-agents/mlagents/trainers/demo_loader.py
  2  ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  6  ml-agents/mlagents/trainers/policy/policy.py
  2  ml-agents/mlagents/trainers/policy/torch_policy.py
  2  ml-agents/mlagents/trainers/ppo/optimizer_torch.py
 18  ml-agents/mlagents/trainers/sac/optimizer_torch.py
  3  ml-agents/mlagents/trainers/tests/check_env_trains.py
 11  ml-agents/mlagents/trainers/tests/dummy_config.py
 32  ml-agents/mlagents/trainers/tests/mock_brain.py
 15  ml-agents/mlagents/trainers/tests/simple_test_envs.py
 16  ml-agents/mlagents/trainers/tests/test_agent_processor.py
  4  ml-agents/mlagents/trainers/tests/test_demo_loader.py
  6  ml-agents/mlagents/trainers/tests/test_rl_trainer.py
  4  ml-agents/mlagents/trainers/tests/test_trajectory.py
  3  ml-agents/mlagents/trainers/tests/torch/test_ghost.py
 12  ml-agents/mlagents/trainers/tests/torch/test_hybrid.py
 23  ml-agents/mlagents/trainers/tests/torch/test_networks.py
  4  ml-agents/mlagents/trainers/tests/torch/test_policy.py
  2  ml-agents/mlagents/trainers/tests/torch/test_ppo.py
 40  ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
 13  ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py
 27  ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py
 28  ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py
  8  ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py
  2  ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
  4  ml-agents/mlagents/trainers/tests/torch/test_utils.py
  2  ml-agents/mlagents/trainers/torch/components/bc/module.py
  4  ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
  4  ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
  2  ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py
 10  ml-agents/mlagents/trainers/torch/model_serialization.py
 32  ml-agents/mlagents/trainers/torch/networks.py
 10  ml-agents/mlagents/trainers/torch/utils.py
  4  ml-agents/tests/yamato/scripts/run_llapi.py
  1  protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
 95  com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs
 11  com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs.meta
 41  com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs
 11  com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs.meta
 47  com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs
 11  com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs.meta
162  ml-agents/mlagents/trainers/tests/torch/test_attention.py
191  ml-agents/mlagents/trainers/torch/attention.py
com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs

using System;

namespace Unity.MLAgents.Sensors
{
    public class BufferSensor : ISensor, IDimensionPropertiesSensor
    {
        private int m_MaxNumObs;
        private int m_ObsSize;
        float[] m_ObservationBuffer;
        int m_CurrentNumObservables;

        public BufferSensor(int maxNumberObs, int obsSize)
        {
            m_MaxNumObs = maxNumberObs;
            m_ObsSize = obsSize;
            m_ObservationBuffer = new float[m_ObsSize * m_MaxNumObs];
            m_CurrentNumObservables = 0;
        }

        /// <inheritdoc/>
        public int[] GetObservationShape()
        {
            return new int[] { m_MaxNumObs, m_ObsSize };
        }

        /// <inheritdoc/>
        public DimensionProperty[] GetDimensionProperties()
        {
            return new DimensionProperty[]{
                DimensionProperty.VariableSize,
                DimensionProperty.None
            };
        }

        /// <summary>
        /// Appends an observation to the buffer. If the buffer is full (the maximum
        /// number of observations has been reached), the observation is ignored. The
        /// length of the provided observation array must be equal to the observation
        /// size of the buffer sensor.
        /// </summary>
        /// <param name="obs">The float array observation</param>
        public void AppendObservation(float[] obs)
        {
            if (m_CurrentNumObservables >= m_MaxNumObs)
            {
                return;
            }
            for (int i = 0; i < obs.Length; i++)
            {
                m_ObservationBuffer[m_CurrentNumObservables * m_ObsSize + i] = obs[i];
            }
            m_CurrentNumObservables++;
        }

        /// <inheritdoc/>
        public int Write(ObservationWriter writer)
        {
            for (int i = 0; i < m_ObsSize * m_MaxNumObs; i++)
            {
                writer[i] = m_ObservationBuffer[i];
            }
            return m_ObsSize * m_MaxNumObs;
        }

        /// <inheritdoc/>
        public virtual byte[] GetCompressedObservation()
        {
            return null;
        }

        /// <inheritdoc/>
        public void Update()
        {
            Reset();
        }

        /// <inheritdoc/>
        public void Reset()
        {
            m_CurrentNumObservables = 0;
            Array.Clear(m_ObservationBuffer, 0, m_ObservationBuffer.Length);
        }

        public SensorCompressionType GetCompressionType()
        {
            return SensorCompressionType.None;
        }

        public string GetName()
        {
            return "BufferSensor";
        }
    }
}
com.unity.ml-agents/Runtime/Sensors/BufferSensor.cs.meta

fileFormatVersion: 2
guid: 034f05c858e684e5498d9a548c9d1fc5
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:
com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs

using UnityEngine;

namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// A component for BufferSensor.
    /// </summary>
    [AddComponentMenu("ML Agents/Buffer Sensor", (int)MenuGroup.Sensors)]
    public class BufferSensorComponent : SensorComponent
    {
        /// <summary>The size of each observation appended to the buffer.</summary>
        public int ObservableSize;

        /// <summary>The maximum number of observations the buffer can hold.</summary>
        public int MaxNumObservables;

        private BufferSensor m_Sensor;

        /// <inheritdoc/>
        public override ISensor CreateSensor()
        {
            m_Sensor = new BufferSensor(MaxNumObservables, ObservableSize);
            return m_Sensor;
        }

        /// <inheritdoc/>
        public override int[] GetObservationShape()
        {
            return new[] { MaxNumObservables, ObservableSize };
        }

        /// <summary>
        /// Appends an observation to the buffer. If the buffer is full (the maximum
        /// number of observations has been reached), the observation is ignored. The
        /// length of the provided observation array must be equal to the observation
        /// size of the buffer sensor.
        /// </summary>
        /// <param name="obs">The float array observation</param>
        public void AppendObservation(float[] obs)
        {
            m_Sensor.AppendObservation(obs);
        }
    }
}
com.unity.ml-agents/Runtime/Sensors/BufferSensorComponent.cs.meta

fileFormatVersion: 2
guid: dd8012d5925524537b27131fef517017
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:
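Because BufferSensor zero-fills unused slots (see Reset and the partial writes in AppendObservation), padding can be recovered on the trainer side by checking for all-zero rows. A minimal PyTorch sketch of that convention follows; the shapes and variable names are illustrative, not part of this diff, and the mask rule mirrors SimpleTransformer.get_masks in the attention module further down.

import torch

# Illustrative shapes: a BufferSensorComponent with MaxNumObservables = 4
# and ObservableSize = 3, observed over a batch of 2 agents.
batch_size, max_num_obs, obs_size = 2, 4, 3

# The sensor zero-fills unused slots, so padding rows arrive as all zeros.
buffer_obs = torch.zeros((batch_size, max_num_obs, obs_size))
buffer_obs[:, :2, :] = 1.0  # pretend two real entities were appended

# Mask is 1 for all-zero (padding) rows and 0 for real entities -- the same
# convention used by SimpleTransformer.get_masks below.
mask = (torch.sum(buffer_obs ** 2, dim=2) < 0.01).float()
assert mask[0].tolist() == [0.0, 0.0, 1.0, 1.0]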
com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs

namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// The dimension property flags of the observations.
    /// </summary>
    [System.Flags]
    public enum DimensionProperty
    {
        /// <summary>
        /// No properties specified.
        /// </summary>
        Unspecified = 0,

        /// <summary>
        /// No special property for the observation in this dimension. The observation
        /// can be processed with fully connected networks.
        /// </summary>
        None = 1,

        /// <summary>
        /// Means it is suitable to do a convolution in this dimension.
        /// </summary>
        TranslationalEquivariance = 2,

        /// <summary>
        /// Means that there can be a variable number of observations in this dimension.
        /// The observations are unordered.
        /// </summary>
        VariableSize = 4,
    }

    /// <summary>
    /// Sensor interface for sensors with special dimension properties.
    /// </summary>
    public interface IDimensionPropertiesSensor
    {
        /// <summary>
        /// Returns the array containing the properties of each dimension of the
        /// observation. The length of the array must be equal to the rank of the
        /// observation tensor.
        /// </summary>
        /// <returns>The array of DimensionProperty</returns>
        DimensionProperty[] GetDimensionProperties();
    }
}
com.unity.ml-agents/Runtime/Sensors/IDimensionPropertiesSensor.cs.meta

fileFormatVersion: 2
guid: 297e9ec12d6de45adbcf6dea1a9de019
MonoImporter:
  externalObjects: {}
  serializedVersion: 2
  defaultReferences: []
  executionOrder: 0
  icon: {instanceID: 0}
  userData:
  assetBundleName:
  assetBundleVariant:
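DimensionProperty is a flags enum, so per-dimension properties can in principle be combined. A rough Python mirror of the C# values, just to illustrate how they line up (illustrative only; the canonical Python-side definition touched by this PR lives in ml-agents-envs/mlagents_envs/base_env.py):

from enum import IntFlag

class DimensionProperty(IntFlag):
    UNSPECIFIED = 0
    NONE = 1
    TRANSLATIONAL_EQUIVARIANCE = 2
    VARIABLE_SIZE = 4

# The BufferSensor above reports (VariableSize, None): a variable number of
# unordered entities, each a plain fixed-size vector.
buffer_props = (DimensionProperty.VARIABLE_SIZE, DimensionProperty.NONE)

# Flags can combine, e.g. a dimension both variable in size and
# convolution-friendly (purely hypothetical here):
combined = DimensionProperty.VARIABLE_SIZE | DimensionProperty.TRANSLATIONAL_EQUIVARIANCE
assert combined & DimensionProperty.VARIABLE_SIZE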
ml-agents/mlagents/trainers/tests/torch/test_attention.py

from mlagents.torch_utils import torch
import numpy as np

from mlagents.trainers.torch.layers import linear_layer
from mlagents.trainers.torch.attention import MultiHeadAttention, SimpleTransformer


def test_multi_head_attention_initialization():
    q_size, k_size, v_size, o_size, n_h, emb_size = 7, 8, 9, 10, 11, 12
    n_k, n_q, b = 13, 14, 15
    mha = MultiHeadAttention(q_size, k_size, v_size, o_size, n_h, emb_size)

    query = torch.ones((b, n_q, q_size))
    key = torch.ones((b, n_k, k_size))
    value = torch.ones((b, n_k, v_size))

    output, attention = mha.forward(query, key, value)

    assert output.shape == (b, n_q, o_size)
    assert attention.shape == (b, n_h, n_q, n_k)


def test_multi_head_attention_masking():
    epsilon = 0.0001
    q_size, k_size, v_size, o_size, n_h, emb_size = 7, 8, 9, 10, 11, 12
    n_k, n_q, b = 13, 14, 15
    mha = MultiHeadAttention(q_size, k_size, v_size, o_size, n_h, emb_size)

    # Create a key input in which every third key is all zeros and masked out
    key = torch.ones((b, n_k, k_size))
    mask = torch.zeros((b, n_k))
    for i in range(n_k):
        if i % 3 == 0:
            key[:, i, :] = 0
            mask[:, i] = 1

    query = torch.ones((b, n_q, q_size))
    value = torch.ones((b, n_k, v_size))

    _, attention = mha.forward(query, key, value, mask)
    # Masked keys should receive (near) zero attention; unmasked keys should not
    for i in range(n_k):
        if i % 3 == 0:
            assert torch.sum(attention[:, :, :, i] ** 2) < epsilon
        else:
            assert torch.sum(attention[:, :, :, i] ** 2) > epsilon


def test_multi_head_attention_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_h, n_k, n_q = 3, 10, 5, 1
    embedding_size = 64
    mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
    optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(50):
        query = torch.rand((batch_size, n_q, size)) * point_range * 2 - point_range
        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
        value = key
        with torch.no_grad():
            # Create the target: the key closest to the query in Euclidean distance
            distance = torch.sum((query - key) ** 2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        prediction, _ = mha.forward(query, key, value)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.5


def test_zero_mask_layer():
    batch_size, size = 10, 30

    def generate_input_helper(pattern):
        # Alternate random (even index) and all-zero (odd index) segments
        _input = torch.zeros((batch_size, 0, size))
        for i in range(len(pattern)):
            if i % 2 == 0:
                _input = torch.cat(
                    [_input, torch.rand((batch_size, pattern[i], size))], dim=1
                )
            else:
                _input = torch.cat(
                    [_input, torch.zeros((batch_size, pattern[i], size))], dim=1
                )
        return _input

    masking_pattern_1 = [3, 2, 3, 4]
    masking_pattern_2 = [5, 7, 8, 2]
    input_1 = generate_input_helper(masking_pattern_1)
    input_2 = generate_input_helper(masking_pattern_2)

    masks = SimpleTransformer.get_masks([input_1, input_2])
    assert len(masks) == 2
    masks_1 = masks[0]
    masks_2 = masks[1]
    assert masks_1.shape == (batch_size, sum(masking_pattern_1))
    assert masks_2.shape == (batch_size, sum(masking_pattern_2))

    def check_mask(mask, pattern):
        # Random segments (even index) must be unmasked (0); zero-padding
        # segments (odd index) must be masked (1)
        offset = 0
        for j, length in enumerate(pattern):
            expected = 0 if j % 2 == 0 else 1
            assert torch.all(mask[:, offset : offset + length] == expected)
            offset += length

    check_mask(masks_1, masking_pattern_1)
    check_mask(masks_2, masking_pattern_2)


def test_simple_transformer_training():
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    transformer = SimpleTransformer(size, [size], embedding_size)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001
    )
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(100):
        center = torch.rand((batch_size, size)) * point_range * 2 - point_range
        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
        with torch.no_grad():
            # Create the target: the key closest to the center in Euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        masks = SimpleTransformer.get_masks([key])
        prediction = transformer.forward(center, [key], masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.3
ml-agents/mlagents/trainers/torch/attention.py

from mlagents.torch_utils import torch
from typing import Tuple, Optional, List
from mlagents.trainers.torch.layers import LinearEncoder


class MultiHeadAttention(torch.nn.Module):
    """
    Multi Head Attention module. We do not use the regular Torch implementation since
    Barracuda does not support some operators it uses.
    Takes as input to the forward method 3 tensors:
    - query: of dimensions (batch_size, number_of_queries, query_size)
    - key: of dimensions (batch_size, number_of_keys, key_size)
    - value: of dimensions (batch_size, number_of_keys, value_size)
    The forward method will return 2 tensors:
    - The output: (batch_size, number_of_queries, output_size)
    - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
    """

    NEG_INF = -1e6

    def __init__(
        self,
        query_size: int,
        key_size: int,
        value_size: int,
        output_size: int,
        num_heads: int,
        embedding_size: int,
    ):
        super().__init__()
        self.n_heads, self.embedding_size = num_heads, embedding_size
        self.output_size = output_size
        self.fc_q = torch.nn.Linear(query_size, self.n_heads * self.embedding_size)
        self.fc_k = torch.nn.Linear(key_size, self.n_heads * self.embedding_size)
        self.fc_v = torch.nn.Linear(value_size, self.n_heads * self.embedding_size)
        # self.fc_q = LinearEncoder(query_size, 2, self.n_heads * self.embedding_size)
        # self.fc_k = LinearEncoder(key_size, 2, self.n_heads * self.embedding_size)
        # self.fc_v = LinearEncoder(value_size, 2, self.n_heads * self.embedding_size)
        self.fc_out = torch.nn.Linear(
            self.n_heads * self.embedding_size, self.output_size
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_mask: Optional[torch.Tensor] = None,
        number_of_keys: int = -1,
        number_of_queries: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        b = -1  # the batch size
        # Use the explicit counts when provided: Barracuda does not support .size(),
        # so it is only used as a fallback
        n_q = number_of_queries if number_of_queries != -1 else query.size(1)
        n_k = number_of_keys if number_of_keys != -1 else key.size(1)

        query = self.fc_q(query)  # (b, n_q, h*emb)
        key = self.fc_k(key)  # (b, n_k, h*emb)
        value = self.fc_v(value)  # (b, n_k, h*emb)

        query = query.reshape(b, n_q, self.n_heads, self.embedding_size)
        key = key.reshape(b, n_k, self.n_heads, self.embedding_size)
        value = value.reshape(b, n_k, self.n_heads, self.embedding_size)

        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb)
        # The next few lines are equivalent to: key.permute([0, 2, 3, 1])
        # This is a hack: ONNX will fuse two consecutive permute operations and
        # Barracuda will not accept the resulting permute([0, 2, 3, 1])
        key = key.permute([0, 2, 1, 3])  # (b, h, n_k, emb)
        key -= 1
        key += 1
        key = key.permute([0, 1, 3, 2])  # (b, h, emb, n_k)

        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)

        if key_mask is None:
            qk = qk / (self.embedding_size ** 0.5)
        else:
            key_mask = key_mask.reshape(b, 1, 1, n_k)
            qk = (1 - key_mask) * qk / (
                self.embedding_size ** 0.5
            ) + key_mask * self.NEG_INF

        att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)

        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb)
        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb)

        value_attention = value_attention.permute([0, 2, 1, 3])  # (b, n_q, h, emb)
        value_attention = value_attention.reshape(
            b, n_q, self.n_heads * self.embedding_size
        )  # (b, n_q, h*emb)

        out = self.fc_out(value_attention)  # (b, n_q, output_size)
        return out, att


class SimpleTransformer(torch.nn.Module):
    """
    A simple architecture inspired from https://arxiv.org/pdf/1909.07528.pdf that uses
    multi head self attention to encode information about a "Self" and a list of
    relevant "Entities".
    """

    EPSILON = 1e-7

    def __init__(
        self,
        x_self_size: int,
        entities_sizes: List[int],
        embedding_size: int,
        output_size: Optional[int] = None,
    ):
        super().__init__()
        self.self_size = x_self_size
        self.entities_sizes = entities_sizes
        self.entities_num_max_elements: Optional[List[int]] = None
        self.ent_encoders = torch.nn.ModuleList(
            [
                LinearEncoder(self.self_size + ent_size, 2, embedding_size)
                for ent_size in self.entities_sizes
            ]
        )
        self.attention = MultiHeadAttention(
            query_size=embedding_size,
            key_size=embedding_size,
            value_size=embedding_size,
            output_size=embedding_size,
            num_heads=4,
            embedding_size=embedding_size,
        )
        self.residual_layer = LinearEncoder(embedding_size, 1, embedding_size)
        if output_size is None:
            output_size = embedding_size
        self.x_self_residual_layer = LinearEncoder(
            embedding_size + x_self_size, 1, output_size
        )

    def forward(
        self,
        x_self: torch.Tensor,
        entities: List[torch.Tensor],
        key_masks: List[torch.Tensor],
    ) -> torch.Tensor:
        # Gather the maximum number of entities information
        if self.entities_num_max_elements is None:
            self.entities_num_max_elements = []
            for ent in entities:
                self.entities_num_max_elements.append(ent.shape[1])
        # Concatenate all observations with self
        self_and_ent: List[torch.Tensor] = []
        for num_entities, ent in zip(self.entities_num_max_elements, entities):
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            # torch.cat is used here in place of .repeat(1, num_entities, 1)
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
        # Generate the tensor that will serve as query, key and value to self attention
        qkv = torch.cat(
            [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
            dim=1,
        )
        mask = torch.cat(key_masks, dim=1)
        # Feed to self attention
        max_num_ent = sum(self.entities_num_max_elements)
        output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent, max_num_ent)
        # Residual
        output = self.residual_layer(output) + qkv
        # Average pooling over the unmasked entities
        numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1), dim=1)
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
        output = numerator / denominator
        # Residual between x_self and the output of the module
        output = self.x_self_residual_layer(torch.cat([output, x_self], dim=1))
        return output

    @staticmethod
    def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Takes a List of Tensors and returns a List of mask Tensors with 1 if the
        input was all zeros (on dimension 2) and 0 otherwise. This is used in the
        attention layer to mask the padding observations.
        """
        with torch.no_grad():
            # Generate the masking tensors for each entities tensor (mask only if all zeros)
            key_masks: List[torch.Tensor] = [
                (torch.sum(ent ** 2, dim=2) < 0.01).type(torch.FloatTensor)
                for ent in observations
            ]
        return key_masks
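For reference, per head the forward pass above computes masked scaled dot-product attention. With key mask m_j in {0, 1}, embedding size d, and H heads, the scores, attention weights, and output are (notation ours, written to match the code, where -10^6 plays the role of negative infinity):

s_{ij} = (1 - m_j)\,\frac{q_i^{\top} k_j}{\sqrt{d}} - 10^{6}\, m_j,
\qquad
A_{ij} = \frac{e^{s_{ij}}}{\sum_{j'} e^{s_{ij'}}},
\qquad
\mathrm{out}_i = W_{\mathrm{out}} \left[ \big\Vert_{h=1}^{H} \sum_{j} A^{(h)}_{ij} v^{(h)}_j \right]

where \Vert denotes concatenation of the H heads, so a masked key (m_j = 1) contributes a score of roughly -10^6 and receives near-zero attention weight after the softmax.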