from typing import List

from mlagents.torch_utils import torch
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM, LinearEncoder, Initialization
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
    get_zero_entities_mask,
)
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
|
|
# Accumulate the maximum number of attended entities across all
# variable-length processors (zero means the max was unknown)
entity_num_max: int = 0
self.var_processors = [
    p for p in self.processors if isinstance(p, EntityEmbedding)
]
for processor in self.var_processors:
    entity_max: int = processor.entity_num_max_elements
    # Only add the entity max if it was known at construction
    if entity_max > 0:
        entity_num_max += entity_max

if len(self.var_processors) > 0:
    if sum(self.embedding_sizes):
        # Encodes the agent's own (fixed-size) observations so they can
        # condition the variable-length entity embeddings
        self.x_self_encoder = LinearEncoder(
            sum(self.embedding_sizes),
            1,
            self.h_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / self.h_size) ** 0.5,
        )
    self.rsa = ResidualSelfAttention(self.h_size, entity_num_max)
    total_enc_size = sum(self.embedding_sizes) + self.h_size
    # The entity embedding and the attention block already provide two
    # layers of processing, so the trunk is made correspondingly shallower
    n_layers = max(1, network_settings.num_layers - 2)
else:
    total_enc_size = sum(self.embedding_sizes)
    n_layers = max(1, network_settings.num_layers)

self.linear_encoder = LinearEncoder(total_enc_size, n_layers, self.h_size)

if self.use_lstm:
    self.lstm = LSTM(self.h_size, self.m_size)
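As a standalone sanity check of the self-encoder configuration above, a minimal sketch (the hidden size of 128 and the 6-dimensional self observation are illustrative values, not taken from the original):

import torch

from mlagents.trainers.torch.layers import LinearEncoder, Initialization

h_size = 128    # illustrative hidden size
x_self_dim = 6  # illustrative width of the agent's own observation vector

x_self_encoder = LinearEncoder(
    x_self_dim,
    1,  # a single layer, as in the constructor above
    h_size,
    kernel_init=Initialization.Normal,
    kernel_gain=(0.125 / h_size) ** 0.5,  # small gain keeps the initial output low-magnitude
)

encoded = x_self_encoder(torch.randn(1, x_self_dim))
print(encoded.shape)  # torch.Size([1, 128])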
|
|
|
|
|
|
# Some inputs need to be processed with a variable length encoder
masks = get_zero_entities_mask(var_len_inputs)
embeddings: List[torch.Tensor] = []
if input_exist:
    # Condition each entity on the encoded self observations, not on the
    # raw concatenation of the other inputs
    processed_self = self.x_self_encoder(encoded_self)
    for var_len_input, var_len_processor in zip(
        var_len_inputs, self.var_processors
    ):
        embeddings.append(var_len_processor(processed_self, var_len_input))
else:
    for var_len_input, var_len_processor in zip(
        var_len_inputs, self.var_processors
    ):
        embeddings.append(var_len_processor(None, var_len_input))
qkv = torch.cat(embeddings, dim=1)
attention_embedding = self.rsa(qkv, masks)
if not input_exist:
    encoded_self = attention_embedding
    input_exist = True
else:
    encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
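Taken in isolation, the attention path behaves as a permutation-invariant pooling over a padded entity list. A minimal sketch of that behavior, assuming the attention module's EntityEmbedding(entity_size, entity_num_max_elements, embedding_size) constructor and illustrative sizes (batch of 2, up to 3 entities of width 4, embedding size 64); x_self is left as None, so there is no self-conditioning:

import torch

from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
    get_zero_entities_mask,
)

batch, max_entities, entity_size, emb_size = 2, 3, 4, 64  # illustrative sizes

embedding = EntityEmbedding(entity_size, max_entities, emb_size)
rsa = ResidualSelfAttention(emb_size, max_entities)

entities = torch.randn(batch, max_entities, entity_size)
entities[:, -1, :] = 0.0  # an all-zero row stands for a padded, absent entity
masks = get_zero_entities_mask([entities])  # 1.0 where an entity is padding

qkv = embedding(None, entities)  # (batch, max_entities, emb_size)
pooled = rsa(qkv, masks)         # masked entities are excluded from the pooling
print(pooled.shape)              # torch.Size([2, 64])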
|
|
|