
addressing comments

Branch: develop/singular-embeddings
vincentpierre · 4 years ago
Commit: 7e47f94b
4 files changed, 20 insertions(+), 11 deletions(-)
1. ml-agents/mlagents/trainers/tests/torch/test_attention.py (2 changes)
2. ml-agents/mlagents/trainers/torch/attention.py (9 changes)
3. ml-agents/mlagents/trainers/torch/networks.py (12 changes)
4. ml-agents/mlagents/trainers/torch/utils.py (8 changes)

ml-agents/mlagents/trainers/tests/torch/test_attention.py (2 changes)


 torch.manual_seed(1336)
 size, n_k, = 3, 5
 embedding_size = 64
-entity_embeddings = EntityEmbedding(size, size, n_k, embedding_size)
+entity_embeddings = EntityEmbedding(size, size, n_k, embedding_size, True)
 transformer = ResidualSelfAttention(embedding_size, n_k)
 l_layer = linear_layer(embedding_size, size)
 optimizer = torch.optim.Adam(
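Since concat_self no longer has a default (see attention.py below), the test now passes it explicitly. Below is a minimal sketch of the setup this hunk belongs to, assuming the signatures and import locations shown in this diff; the batch size, learning rate, and forward calls are illustrative, not quoted from the test.

# Minimal sketch, assuming the signatures shown in this diff; shapes and
# hyperparameters are illustrative.
import torch
from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
    get_zero_entities_mask,
)
from mlagents.trainers.torch.layers import linear_layer

size, n_k = 3, 5
embedding_size = 64
# concat_self is now a required argument, so the test passes True explicitly
entity_embeddings = EntityEmbedding(size, size, n_k, embedding_size, True)
transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(
    list(entity_embeddings.parameters())
    + list(transformer.parameters())
    + list(l_layer.parameters()),
    lr=0.001,
)

x_self = torch.rand((8, size))         # batch of 8 "self" observations
entities = torch.rand((8, n_k, size))  # up to n_k entities per observation
masks = get_zero_entities_mask([entities])
prediction = l_layer(transformer(entity_embeddings(x_self, entities), masks))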

ml-agents/mlagents/trainers/torch/attention.py (9 changes)


 class EntityEmbedding(torch.nn.Module):
     """
     A module used to embed entities before passing them to a self-attention block.
     Used in conjunction with ResidualSelfAttention to encode information about a self
     and additional entities. Can also concatenate self to entities for ego-centric self-
     attention. Inspired by architecture used in https://arxiv.org/pdf/1909.07528.pdf.
     """
     def __init__(
         self,
         x_self_size: int,

-        concat_self: bool = True,
+        concat_self: bool,
     ):
         """
         Constructs an EntityEmbedding module.
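The ego-centric concatenation the docstring mentions can be pictured as follows. This is an illustrative plain-PyTorch sketch of the idea, not the module's actual implementation.

# Illustrative sketch of ego-centric concatenation (concat_self=True): the
# self encoding is tiled across the entity axis and concatenated onto each
# entity before the per-entity embedding is computed.
import torch

batch, n_entities = 8, 5
x_self_size, entity_size, embedding_size = 3, 3, 64

x_self = torch.rand(batch, x_self_size)
entities = torch.rand(batch, n_entities, entity_size)

expanded_self = x_self.unsqueeze(1).expand(-1, n_entities, -1)
ego_entities = torch.cat([expanded_self, entities], dim=2)

embed = torch.nn.Linear(x_self_size + entity_size, embedding_size)
embeddings = embed(ego_entities)  # (batch, n_entities, embedding_size)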

ml-agents/mlagents/trainers/torch/networks.py (12 changes)


             encodes.append(processed_obs)
         else:
             var_len_inputs.append(inputs[idx])
-        if len(encodes) == 0:
-            encoded_self = torch.zeros(0, 0)
-        else:
+        if len(encodes) != 0:
+            input_exist = True
+        else:
+            input_exist = False
         if len(var_len_inputs) > 0:
             # Some inputs need to be processed with a variable length encoder
             masks = get_zero_entities_mask(var_len_inputs)

             embeddings.append(var_len_processor(encoded_self, var_len_input))
             qkv = torch.cat(embeddings, dim=1)
             attention_embedding = self.rsa(qkv, masks)
-            if encoded_self.shape[1] == 0:
+            if not input_exist:
+                input_exist = True

-        if encoded_self.shape[1] == 0:
+        if not input_exist:
             raise Exception("No valid inputs to network.")
         # Constants don't work in Barracuda
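Net effect: the torch.zeros(0, 0) sentinel (a constant, which the trailing comment notes Barracuda cannot handle) is replaced by an explicit input_exist flag. Below is a hedged reconstruction of the resulting flow; the branch bodies paraphrase context the diff view collapsed and are not quoted from the commit.

# Hedged reconstruction -- collapsed context is paraphrased, not quoted.
input_exist = len(encodes) != 0       # was: encoded_self = torch.zeros(0, 0)
if len(var_len_inputs) > 0:
    masks = get_zero_entities_mask(var_len_inputs)
    # ... each var_len_processor embeds (encoded_self, var_len_input) ...
    qkv = torch.cat(embeddings, dim=1)
    attention_embedding = self.rsa(qkv, masks)
    if not input_exist:               # was: encoded_self.shape[1] == 0
        encoded_self = attention_embedding  # assumed branch body
        input_exist = True
if not input_exist:                   # was: encoded_self.shape[1] == 0
    raise Exception("No valid inputs to network.")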

ml-agents/mlagents/trainers/torch/utils.py (8 changes)


         for idx in var_len_indices:
             var_encoders.append(
                 EntityEmbedding(
-                    x_self_size,
-                    observation_specs[idx].shape[1],
-                    observation_specs[idx].shape[0],
-                    h_size,
+                    x_self_size=x_self_size,
+                    entity_size=observation_specs[idx].shape[1],
+                    entity_num_max_elements=observation_specs[idx].shape[0],
+                    embedding_size=h_size,
+                    concat_self=True,
                 )
             )
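Switching from positional to keyword arguments makes the shape convention explicit and guards against transposing shape[0] and shape[1]. A small worked example of the mapping this call assumes, with an illustrative spec shape.

# Illustrative shapes only: a variable-length observation spec is
# (max_num_entities, per_entity_size).
spec_shape = (20, 6)                     # e.g. up to 20 entities, 6 features each
entity_num_max_elements = spec_shape[0]  # 20: rows index the entities
entity_size = spec_shape[1]              # 6: columns are per-entity features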
