
Make a self encoder before EntityEmbedding

/develop/singular-embeddings
vincentpierre 4 years ago
Current commit
c27a95f0
5 files changed, 32 insertions and 35 deletions
  1. ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (2 changes)
  2. ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (9 changes)
  3. ml-agents/mlagents/trainers/torch/attention.py (23 changes)
  4. ml-agents/mlagents/trainers/torch/networks.py (31 changes)
  5. ml-agents/mlagents/trainers/torch/utils.py (2 changes)

ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (2 changes)

      PPO_TORCH_CONFIG,
      hyperparameters=new_hyperparams,
      network_settings=new_network_settings,
-     max_steps=4000,
+     max_steps=5000,
  )
  check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)
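The only change here is the step budget: the new self-encoder path evidently needs a little longer to reach the 0.9 success threshold, so max_steps goes from 4000 to 5000. For context, a minimal sketch of how such a config override is typically assembled in these tests; the attr.evolve calls and the learning-rate value are assumptions, only the visible keyword arguments come from the diff:

    import attr

    # Assumed reconstruction of the surrounding test body; the diff only
    # shows the keyword arguments passed to the config override.
    new_hyperparams = attr.evolve(
        PPO_TORCH_CONFIG.hyperparameters, learning_rate=1.0e-3
    )
    config = attr.evolve(
        PPO_TORCH_CONFIG,
        hyperparameters=new_hyperparams,
        network_settings=new_network_settings,
        max_steps=5000,  # raised from 4000 in this commit
    )
    check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)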

ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (9 changes)

  @pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)])
  @pytest.mark.parametrize("num_var_len", [1, 2])
- @pytest.mark.parametrize("num_visual", [0, 1])
- def test_var_len_obs_ppo(num_visual, num_var_len, action_sizes):
+ @pytest.mark.parametrize("num_vector", [0, 1])
+ @pytest.mark.parametrize("num_vis", [0, 1])
+ def test_var_len_obs_ppo(num_vis, num_vector, num_var_len, action_sizes):
...
-     num_visual=num_visual,
-     num_vector=0,
+     num_visual=num_vis,
+     num_vector=num_vector,
      num_var_len=num_var_len,
      step_size=0.2,
  )
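The test now varies vector observations as well: stacked parametrize decorators expand to the full Cartesian product, so the matrix grows from 2 × 2 × 2 = 8 cases to 2 × 2 × 2 × 2 = 16. A minimal self-contained illustration of that stacking behavior (the test name and assert are invented for the example):

    import pytest

    @pytest.mark.parametrize("num_vector", [0, 1])
    @pytest.mark.parametrize("num_vis", [0, 1])
    def test_stacking_example(num_vis, num_vector):
        # pytest generates one case per (num_vis, num_vector) pair:
        # (0, 0), (0, 1), (1, 0), (1, 1)
        assert num_vis in (0, 1) and num_vector in (0, 1)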

ml-agents/mlagents/trainers/torch/attention.py (23 changes)

  if not concat_self:
      self.self_size = 0
  # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
- if self.self_size > 0:
-     self.self_encoder = LinearEncoder(
-         self.self_size,
-         1,
-         embedding_size // 2,
-         kernel_init=Initialization.Normal,
-         kernel_gain=(0.125 / embedding_size) ** 0.5,
-     )
- self.ent_encoder = LinearEncoder(
-     self.entity_size,
-     1,
-     embedding_size - (embedding_size // 2),
-     kernel_init=Initialization.Normal,
-     kernel_gain=(0.125 / embedding_size) ** 0.5,
- )
  self.self_ent_encoder = LinearEncoder(
-     embedding_size if self.self_size > 0 else self.entity_size,
+     self.self_size + self.entity_size,
      1,
      embedding_size,
      kernel_init=Initialization.Normal,
...
  expanded_self = x_self.reshape(-1, 1, self.self_size)
  expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
  # Concatenate all observations with self
- entities = torch.cat(
-     [self.self_encoder(expanded_self), self.ent_encoder(entities)], dim=2
- )
+ entities = torch.cat([expanded_self, entities], dim=2)
  # Encode entities
  encoded_entities = self.self_ent_encoder(entities)
  return encoded_entities
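In short, EntityEmbedding no longer keeps separate self_encoder/ent_encoder halves: the self tensor (already pre-encoded upstream, see networks.py below) is concatenated raw with each entity, and the pair is encoded jointly by the single self_ent_encoder. A runnable toy sketch of the new concatenate-then-encode path, with nn.Linear standing in for LinearEncoder and all sizes invented:

    import torch
    from torch import nn

    self_size, entity_size, embedding_size, num_entities = 8, 6, 16, 3
    # Single joint encoder over [self, entity], as in the new code.
    self_ent_encoder = nn.Linear(self_size + entity_size, embedding_size)

    x_self = torch.randn(2, self_size)                    # (batch, self_size)
    entities = torch.randn(2, num_entities, entity_size)  # (batch, n, entity_size)

    expanded_self = x_self.reshape(-1, 1, self_size)
    expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
    # Concatenate self with every entity, then encode once.
    encoded = self_ent_encoder(torch.cat([expanded_self, entities], dim=2))
    print(encoded.shape)  # torch.Size([2, 3, 16])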

ml-agents/mlagents/trainers/torch/networks.py (31 changes)

  from mlagents.trainers.settings import NetworkSettings
  from mlagents.trainers.torch.utils import ModelUtils
  from mlagents.trainers.torch.decoders import ValueHeads
- from mlagents.trainers.torch.layers import LSTM, LinearEncoder
+ from mlagents.trainers.torch.layers import LSTM, LinearEncoder, Initialization
  from mlagents.trainers.torch.encoders import VectorInput
  from mlagents.trainers.buffer import AgentBuffer
  from mlagents.trainers.trajectory import ObsUtil
...
  if entity_max > 0:
      entity_num_max += entity_max
- if len(self.var_processors) > 0:
-     n_layers = max(1, network_settings.num_layers - 2)
- else:
-     n_layers = max(1, network_settings.num_layers)
- self.linear_encoder = LinearEncoder(total_enc_size, n_layers, self.h_size)
+ if sum(self.embedding_sizes):
+     self.x_self_encoder = LinearEncoder(
+         sum(self.embedding_sizes),
+         1,
+         self.h_size,
+         kernel_init=Initialization.Normal,
+         kernel_gain=(0.125 / self.h_size) ** 0.5,
+     )
+ self.linear_encoder = LinearEncoder(
+     total_enc_size, network_settings.num_layers, self.h_size
+ )
  if self.use_lstm:
      self.lstm = LSTM(self.h_size, self.m_size)
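This is where the new self encoder lives: whenever any fixed-size inputs exist (sum(self.embedding_sizes) is non-zero), NetworkBody builds an x_self_encoder that compresses their concatenation down to h_size before it is handed to each EntityEmbedding; the main linear encoder meanwhile always uses the configured num_layers instead of dropping two layers when attention is in play. A minimal sketch of the construction logic, with nn.Linear standing in for LinearEncoder and the size values invented:

    from torch import nn

    embedding_sizes = [32, 16]  # per-fixed-input encoder widths (assumed)
    h_size = 64

    x_self_encoder = None
    if sum(embedding_sizes):
        # Encode the concatenated fixed-size observations to h_size so every
        # EntityEmbedding receives a uniformly sized "self" tensor.
        x_self_encoder = nn.Linear(sum(embedding_sizes), h_size)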

  # Some inputs need to be processed with a variable length encoder
  masks = get_zero_entities_mask(var_len_inputs)
  embeddings: List[torch.Tensor] = []
- for var_len_input, var_len_processor in zip(
-     var_len_inputs, self.var_processors
- ):
-     embeddings.append(var_len_processor(encoded_self, var_len_input))
+ if input_exist:
+     processed_self = self.x_self_encoder(encoded_self)
+     for var_len_input, var_len_processor in zip(
+         var_len_inputs, self.var_processors
+     ):
+         embeddings.append(var_len_processor(processed_self, var_len_input))
+ else:
+     for var_len_input, var_len_processor in zip(
+         var_len_inputs, self.var_processors
+     ):
+         embeddings.append(var_len_processor(None, var_len_input))
  qkv = torch.cat(embeddings, dim=1)
  attention_embedding = self.rsa(qkv, masks)
  if not input_exist:
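End to end, the forward pass now encodes self once and shares that encoding with every variable-length processor, or passes None when no fixed-size inputs exist. A runnable toy version of the happy path, with invented sizes and nn.Linear stand-ins for the real modules:

    import torch
    from torch import nn

    h_size, entity_size, num_entities, batch = 64, 6, 3, 2
    x_self_encoder = nn.Linear(48, h_size)                 # stand-in self pre-encoder
    ent_encoder = nn.Linear(h_size + entity_size, h_size)  # stand-in EntityEmbedding

    encoded_self = torch.randn(batch, 48)
    var_len_input = torch.randn(batch, num_entities, entity_size)

    processed_self = x_self_encoder(encoded_self)          # encoded once, then shared
    expanded_self = processed_self.unsqueeze(1).expand(-1, num_entities, -1)
    qkv = ent_encoder(torch.cat([expanded_self, var_len_input], dim=2))
    print(qkv.shape)  # torch.Size([2, 3, 64]), ready for the residual self-attention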

ml-agents/mlagents/trainers/torch/utils.py (2 changes)

  for idx in var_len_indices:
      var_encoders.append(
          EntityEmbedding(
-             x_self_size=x_self_size,
+             x_self_size=0 if x_self_size == 0 else h_size,
              entity_size=observation_specs[idx].shape[1],
              entity_num_max_elements=observation_specs[idx].shape[0],
              embedding_size=h_size,
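The bookkeeping follows from the networks.py change: each EntityEmbedding now receives the h_size-wide output of x_self_encoder rather than the raw self observation, so its declared self width must be h_size whenever any self input exists at all. A tiny sketch of that size logic, with invented values:

    raw_self_width, h_size = 37, 64  # assumed example values
    # EntityEmbedding sees the encoded self, so its x_self_size is either 0
    # (no self inputs) or exactly h_size.
    x_self_size_arg = 0 if raw_self_width == 0 else h_size
    assert x_self_size_arg == h_size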
