|
|
|
|
|
|
""" |
|
|
|
Constructs an EntityEmbedding module. |
|
|
|
:param x_self_size: Size of "self" entity. |
|
|
|
:param entity_size: Size of other entitiy. |
|
|
|
:param entity_size: Size of other entities. |
|
|
|
:param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted. |
|
|
|
Needs to be assigned in order for model to be exportable to ONNX and Barracuda. |
|
|
|
:param embedding_size: Embedding size for the entity encoder. |
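# Illustrative construction sketch (not part of this module). The keyword names
# follow the docstring above and the concat_self flag used below; the sizes are
# made-up example values:
#
#     embedding = EntityEmbedding(
#         x_self_size=32,               # size of the agent's own observation
#         entity_size=16,               # size of each other entity's observation
#         entity_num_max_elements=20,   # fixed max entity count (needed for ONNX/Barracuda)
#         embedding_size=64,            # per-entity output embedding size
#         concat_self=True,             # condition each entity on the self observation
#     )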
|
|
|
|
|
|
# Sizes of the "self" observation and of each entity observation
self.self_size = x_self_size
self.entity_size = entity_size
if not concat_self:
    # The "self" observation is not concatenated to the entities, so ignore it
    self.self_size = 0
# Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
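# Each LinearEncoder below is initialized from a normal distribution scaled by
# kernel_gain = (0.125 / embedding_size) ** 0.5, following the scheme referenced
# above. For example, embedding_size=128 gives a gain of (0.125 / 128) ** 0.5
# = 0.03125, i.e. the gain shrinks in proportion to 1 / sqrt(embedding_size).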
|
|
|
if self.self_size > 0:
    # Encode the "self" observation and the entities separately, each into
    # (roughly) half of the final embedding
    self.self_encoder = LinearEncoder(
        self.self_size,
        1,
        embedding_size // 2,
        kernel_init=Initialization.Normal,
        kernel_gain=(0.125 / embedding_size) ** 0.5,
    )
    self.ent_encoder = LinearEncoder(
        self.entity_size,
        1,
        embedding_size - (embedding_size // 2),
        kernel_init=Initialization.Normal,
        kernel_gain=(0.125 / embedding_size) ** 0.5,
    )
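    # The two output widths always sum to embedding_size: embedding_size // 2 for
    # the self encoding plus embedding_size - (embedding_size // 2) for the entity
    # encoding (e.g. 64 -> 32 + 32, 65 -> 32 + 33), so their concatenation matches
    # the input size of self_ent_encoder below.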
|
|
|
|
|
|
|
self.self_ent_encoder = LinearEncoder(
    embedding_size if self.self_size > 0 else self.entity_size,
    1,
    embedding_size,
    kernel_init=Initialization.Normal,
    kernel_gain=(0.125 / embedding_size) ** 0.5,
)
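# When self_size == 0 (concat_self=False), the forward pass skips the
# self/entity split and feeds the raw entities (of size entity_size) directly
# into self_ent_encoder, which is why its input size switches on self.self_size.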
|
|
|
|
|
|
if self.self_size > 0:
    # Tile the "self" observation so there is one copy per entity:
    # (batch, self_size) -> (batch, num_entities, self_size)
    expanded_self = x_self.reshape(-1, 1, self.self_size)
    expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
    # Concatenate all observations with self: encode the tiled self and the
    # entities separately, then join them along the feature dimension so each
    # entity ends up with embedding_size features
    entities = torch.cat(
        [self.self_encoder(expanded_self), self.ent_encoder(entities)], dim=2
    )
# Encode into the final per-entity embedding
encoded_entities = self.self_ent_encoder(entities)
return encoded_entities
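# Forward-pass shape sketch (illustrative; assumes a module constructed as in the
# sketch following the docstring above and called as embedding(x_self, entities)):
#
#     x_self = torch.randn(8, 32)           # (batch, x_self_size)
#     entities = torch.randn(8, 20, 16)     # (batch, num_entities, entity_size)
#     encoded = embedding(x_self, entities)
#     assert encoded.shape == (8, 20, 64)   # one embedding_size vector per entity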
|
|
|
|
|
|
|
|
|
|
|