浏览代码

Making EntityEmbedding encode self and entities separately

/develop/singular-embeddings
vincentpierre 4 年前
当前提交
65c2fa7f
共有 1 个文件被更改,包括 24 次插入5 次删除
  1. 29
      ml-agents/mlagents/trainers/torch/attention.py

29
ml-agents/mlagents/trainers/torch/attention.py


"""
Constructs an EntityEmbedding module.
:param x_self_size: Size of "self" entity.
:param entity_size: Size of other entitiy.
:param entity_size: Size of other entities.
:param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
Needs to be assigned in order for model to be exportable to ONNX and Barracuda.
:param embedding_size: Embedding size for the entity encoder.

if not concat_self:
self.self_size = 0
# Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
self.ent_encoder = LinearEncoder(
self.self_size + self.entity_size,
if self.self_size > 0:
self.self_encoder = LinearEncoder(
self.self_size,
1,
embedding_size // 2,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
self.ent_encoder = LinearEncoder(
self.entity_size,
1,
embedding_size - (embedding_size // 2),
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
self.self_ent_encoder = LinearEncoder(
embedding_size if self.self_size > 0 else self.entity_size,
1,
embedding_size,
kernel_init=Initialization.Normal,

expanded_self = x_self.reshape(-1, 1, self.self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
# Concatenate all observations with self
entities = torch.cat([expanded_self, entities], dim=2)
entities = torch.cat(
[self.self_encoder(expanded_self), self.ent_encoder(entities)], dim=2
)
encoded_entities = self.ent_encoder(entities)
encoded_entities = self.self_ent_encoder(entities)
return encoded_entities

正在加载...
取消
保存