|
|
|
|
|
|
normalize=self.normalize, |
|
|
|
) |
|
|
|
|
|
|
|
total_enc_size = sum(self.embedding_sizes) + encoded_act_size |
|
|
|
|
|
|
|
total_enc_size = sum(self.embedding_sizes) + encoded_act_size - 9 |
|
|
|
|
|
|
|
self.surrogate_predictor = torch.nn.Linear(self.h_size, 9) |
|
|
|
|
|
|
|
|
|
|
|
self.linear_encoder = LinearEncoder( |
|
|
|
total_enc_size, network_settings.num_layers, self.h_size |
|
|
|
) |
|
|
|
|
|
|
actions: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
retrun_target = False |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
|
|
|
|
if obs_input.shape[1] == 9: |
|
|
|
target = obs_input |
|
|
|
if retrun_target: |
|
|
|
return target |
|
|
|
else: |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
if len(encodes) == 0: |
|
|
|
raise Exception("No valid inputs to network.") |
|
|
|
|
|
|
|
|
|
|
encoding, memories = self.lstm(encoding, memories) |
|
|
|
encoding = encoding.reshape([-1, self.m_size // 2]) |
|
|
|
return encoding, memories |
|
|
|
|
|
|
|
def get_surrogate_loss(self, inputs: List[torch.Tensor]) -> torch.Tensor: |
|
|
|
prediction, _ = self.forward(inputs) |
|
|
|
prediction = self.surrogate_predictor(prediction) |
|
|
|
target = self.forward(inputs, retrun_target=True) |
|
|
|
loss = torch.sum((prediction - target) ** 2, dim=1) |
|
|
|
loss = torch.mean(loss) |
|
|
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ValueNetwork(nn.Module): |
|
|
|