|
|
|
|
|
|
total_enc_size = sum(self.embedding_sizes) |
|
|
|
n_layers = max(1, network_settings.num_layers) |
|
|
|
|
|
|
|
total_enc_size = sum(self.embedding_sizes) + encoded_act_size |
|
|
|
total_enc_size += encoded_act_size |
|
|
|
self.linear_encoder = LinearEncoder(total_enc_size, n_layers, self.h_size) |
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
|
|
|
|
|
|
|
# Constants don't work in Barracuda |
|
|
|
if actions is not None: |
|
|
|
inputs = torch.cat(encodes + [actions], dim=-1) |
|
|
|
else: |
|
|
|
inputs = torch.cat(encodes, dim=-1) |
|
|
|
encoding = self.linear_encoder(inputs) |
|
|
|
encoded_self = torch.cat([encoded_self, actions], dim=1) |
|
|
|
encoding = self.linear_encoder(encoded_self) |
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
# Resize to (batch, sequence length, encoding size) |
|
|
|