)
def forward(self, inputs: torch.Tensor, goal: torch.Tensor):
mu = self.mu(inputs, goal)
mu = self.hypernet(inputs, goal)
if self.conditional_sigma:
log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
else: