|
|
|
|
|
|
from mlagents.trainers.torch.utils import ModelUtils |
|
|
|
from mlagents.trainers.torch.decoders import ValueHeads |
|
|
|
from mlagents.trainers.torch.layers import LSTM, LinearEncoder, Initialization |
|
|
|
from mlagents.trainers.torch.encoders import VectorInput |
|
|
|
from mlagents.trainers.torch.encoders import VectorInput, Identity |
|
|
|
from mlagents.trainers.buffer import AgentBuffer |
|
|
|
from mlagents.trainers.trajectory import ObsUtil |
|
|
|
from mlagents.trainers.torch.attention import ( |
|
|
|
|
|
|
var_len_inputs = [] # The list of variable length inputs |
|
|
|
|
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
if processor is not None: |
|
|
|
if not isinstance(processor, Identity): |
|
|
|
# The input can be encoded without having to process other inputs |
|
|
|
obs_input = inputs[idx] |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
|
|
|
# Some inputs need to be processed with a variable length encoder |
|
|
|
masks = get_zero_entities_mask(var_len_inputs) |
|
|
|
embeddings: List[torch.Tensor] = [] |
|
|
|
if input_exist: |
|
|
|
processed_self = self.x_self_encoder(encoded_self) |
|
|
|
for var_len_input, var_len_processor in zip( |
|
|
|
var_len_inputs, self.var_processors |
|
|
|
): |
|
|
|
embeddings.append(var_len_processor(processed_self, var_len_input)) |
|
|
|
else: |
|
|
|
for var_len_input, var_len_processor in zip( |
|
|
|
var_len_inputs, self.var_processors |
|
|
|
): |
|
|
|
embeddings.append(var_len_processor(None, var_len_input)) |
|
|
|
processed_self = self.x_self_encoder(encoded_self) if input_exist else None |
|
|
|
for var_len_input, var_len_processor in zip( |
|
|
|
var_len_inputs, self.var_processors |
|
|
|
): |
|
|
|
embeddings.append(var_len_processor(processed_self, var_len_input)) |
|
|
|
qkv = torch.cat(embeddings, dim=1) |
|
|
|
attention_embedding = self.rsa(qkv, masks) |
|
|
|
if not input_exist: |
|
|
|
|
|
|
encoded_self = torch.cat([encoded_self, attention_embedding], dim=1) |
|
|
|
|
|
|
|
if not input_exist: |
|
|
|
raise Exception("No valid inputs to network.") |
|
|
|
raise Exception( |
|
|
|
"The trainer was unable to process any of the provided inputs. " |
|
|
|
"Make sure the trained agents has at least one sensor attached to them." |
|
|
|
) |
|
|
|
|
|
|
|
# Constants don't work in Barracuda |
|
|
|
if actions is not None: |
|
|
|