 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.torch.decoders import ValueHeads
 from mlagents.trainers.torch.layers import LSTM, LinearEncoder, Initialization
-from mlagents.trainers.torch.encoders import VectorInput, Identity
+from mlagents.trainers.torch.encoders import VectorInput
 from mlagents.trainers.torch.attention import (
     EntityEmbedding,
     ResidualSelfAttention,
     get_zero_entities_mask,
 )

             else 0
         )

-        self.processors, self.var_processors, self.embedding_sizes = ModelUtils.create_input_processors(
+        self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
             observation_specs,
             self.h_size,
             network_settings.vis_encode_type,
|
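The call above builds one input processor per observation spec, and the later loops rely on that list staying index-aligned with the observations handed to forward(). A minimal, self-contained sketch of that pattern, with plain Linear encoders as stand-ins (PerObsEncoders is illustrative, not the ModelUtils implementation):

import torch
from torch import nn


class PerObsEncoders(nn.Module):
    """Toy stand-in: one encoder per observation, applied index by index."""

    def __init__(self, obs_sizes, hidden):
        super().__init__()
        # One module per observation spec, in the same order as the inputs.
        self.processors = nn.ModuleList(nn.Linear(size, hidden) for size in obs_sizes)
        self.embedding_sizes = [hidden for _ in obs_sizes]

    def forward(self, obs_list):
        # obs_list[idx] is always encoded by processors[idx].
        encodes = [proc(obs) for proc, obs in zip(self.processors, obs_list)]
        return torch.cat(encodes, dim=1)


encoder = PerObsEncoders([6, 4], hidden=8)
out = encoder([torch.randn(2, 6), torch.randn(2, 4)])
print(out.shape)  # torch.Size([2, 16])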
|
|
         entity_num_max: int = 0
-        for var_processor in self.var_processors:
-            entity_max: int = var_processor.entity_num_max_elements
+        var_processors = [p for p in self.processors if isinstance(p, EntityEmbedding)]
+        for processor in var_processors:
+            entity_max: int = processor.entity_num_max_elements

-        if len(self.var_processors) > 0:
+        if len(var_processors) > 0:
             if sum(self.embedding_sizes):
                 self.x_self_encoder = LinearEncoder(
                     sum(self.embedding_sizes),
|
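A small self-contained sketch of the constructor logic in this hunk: the maximum entity counts of the EntityEmbedding processors are summed so one attention block can be sized for all entity types, and the "self" encoder is only built when fixed-size embeddings exist to feed it. ToyEntityEmbedding and the Linear stand-in are illustrative assumptions, not ml-agents classes:

from torch import nn


class ToyEntityEmbedding(nn.Module):
    """Stand-in exposing the one attribute the constructor logic reads."""

    def __init__(self, entity_num_max_elements):
        super().__init__()
        self.entity_num_max_elements = entity_num_max_elements


processors = [nn.Linear(6, 8), ToyEntityEmbedding(3), ToyEntityEmbedding(5)]
embedding_sizes = [8, 0, 0]  # assume variable-length obs add no fixed embedding
h_size = 16

var_processors = [p for p in processors if isinstance(p, ToyEntityEmbedding)]
entity_num_max = sum(p.entity_num_max_elements for p in var_processors)  # 3 + 5 = 8

x_self_encoder = None
if len(var_processors) > 0 and sum(embedding_sizes):
    # Only build a self-encoder when there are fixed-size embeddings to summarize.
    x_self_encoder = nn.Linear(sum(embedding_sizes), h_size)
print(entity_num_max, x_self_encoder)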
|
|
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         encodes = []
         var_len_inputs = []  # The list of variable length inputs
+        var_len_processors = [
+            p for p in self.processors if isinstance(p, EntityEmbedding)
+        ]

-            if not isinstance(processor, Identity):
+            if not isinstance(processor, EntityEmbedding):
                 # The input can be encoded without having to process other inputs
                 obs_input = inputs[idx]
                 processed_obs = processor(obs_input)

             masks = get_zero_entities_mask(var_len_inputs)
             embeddings: List[torch.Tensor] = []
             processed_self = self.x_self_encoder(encoded_self) if input_exist else None
-            for var_len_input, var_len_processor in zip(
-                var_len_inputs, self.var_processors
-            ):
-                embeddings.append(var_len_processor(processed_self, var_len_input))
+            for var_len_input, processor in zip(var_len_inputs, var_len_processors):
+                embeddings.append(processor(processed_self, var_len_input))
             qkv = torch.cat(embeddings, dim=1)
             attention_embedding = self.rsa(qkv, masks)
             if not input_exist:
|
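The attention path above masks padded entities, embeds each variable-length input, and concatenates the embeddings along the entity axis before a single ResidualSelfAttention pass. A self-contained toy of those two tensor steps; the all-zero-row padding convention is an assumption about get_zero_entities_mask, and every shape here is made up:

import torch

# Entities arrive as (batch, num_entities, feature); an all-zero row is padding.
entities_a = torch.tensor([[[1.0, 2.0], [0.0, 0.0], [3.0, 1.0]]])  # (1, 3, 2)
entities_b = torch.zeros(1, 2, 2)                                  # (1, 2, 2), all padding

# Assumed convention: 1.0 marks a padded entity that attention should ignore.
masks = [(e.abs().sum(dim=2) == 0).float() for e in (entities_a, entities_b)]
print(masks[0])  # tensor([[0., 1., 0.]])

# Each entity block is embedded to a common width, then concatenated along the
# entity axis so one attention block can attend over every entity type at once.
h = 4
emb_a = torch.randn(1, 3, h)  # stand-in for one EntityEmbedding output
emb_b = torch.randn(1, 2, h)
qkv = torch.cat([emb_a, emb_b], dim=1)
print(qkv.shape)  # torch.Size([1, 5, 4])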
|
|
                 end = start + vec_size
                 inputs.append(concatenated_vec_obs[:, start:end])
                 start = end
-            elif enc is not None:
+            elif isinstance(enc, EntityEmbedding):
                 inputs.append(var_len_inputs[var_len_index])
                 var_len_index += 1
-            else:  # visual input
+            else:

         # End of code to convert the vec and vis obs into a list of inputs for the network
         encoding, memories_out = self.network_body(
             inputs, memories=memories, sequence_length=1
|
|
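The export loop above rebuilds per-observation inputs by walking the processors in order and slicing the concatenated vector observation by each recorded embedding size. A self-contained toy of that slicing pattern (sizes and tensors are made up; this is not the exporter itself):

import torch

embedding_sizes = [3, 2]  # widths of two vector observations
concatenated_vec_obs = torch.randn(4, sum(embedding_sizes))

inputs = []
start = 0
for vec_size in embedding_sizes:
    end = start + vec_size
    inputs.append(concatenated_vec_obs[:, start:end])  # one slice per VectorInput
    start = end
print([tuple(t.shape) for t in inputs])  # [(4, 3), (4, 2)]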