import torch
from typing import Optional
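
# Estimate path: encode a mini-batch (state, and optionally action and done
# flag) into a single input vector for self.encoder.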
encoder_input = self.get_state_encoding(mini_batch)
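# When configured to use actions, append the action vector and the per-step
# done flag to the state encoding.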
if self._settings.use_actions:
    actions = self.get_action_input(mini_batch)
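    # "done" comes in as a flat array; unsqueeze to (batch, 1) so it can be
    # concatenated with the 2-D state and action tensors.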
    dones = torch.as_tensor(mini_batch["done"], dtype=torch.float).unsqueeze(1)
    encoder_input = torch.cat([encoder_input, actions, dones], dim=1)
hidden = self.encoder(encoder_input)
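# z_mu presumably holds the mean of a VAIL latent code; it stays None unless a
# VAIL branch (not shown in this excerpt) fills it in.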
z_mu: Optional[torch.Tensor] = None
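
# Gradient-penalty path: interpolate between policy and expert samples with a
# fresh uniform epsilon per component, presumably for a WGAN-GP-style penalty
# on the discriminator.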
policy_action = self.get_action_input(policy_batch)
expert_action = self.get_action_input(expert_batch)
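# Uniform mixing coefficients, drawn independently per element.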
action_epsilon = torch.rand(policy_action.shape)
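# As in the estimate path, reshape the done flags to (batch, 1) before mixing.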
policy_dones = torch.as_tensor(
    policy_batch["done"], dtype=torch.float
).unsqueeze(1)
expert_dones = torch.as_tensor(
    expert_batch["done"], dtype=torch.float
).unsqueeze(1)
dones_epsilon = torch.rand(policy_dones.shape)
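# Concatenate the epsilon-mixed components along the feature dimension.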
encoder_input = torch.cat(
    [