|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GaussianDistribution(nn.Module):
    """Diagonal Gaussian action-distribution head.

    Maps a hidden representation to a list containing a single
    ``torch.distributions.normal.Normal``. The mean is a learned linear
    projection of the input; the (log) standard deviation is either a
    state-independent learned parameter or, when ``conditional_sigma`` is
    set, a second linear projection of the input.

    NOTE(review): this block was reconstructed from merge-conflict residue —
    two ``__init__`` signatures and two ``forward`` return paths were
    present. The newer ``conditional_sigma`` variant is kept; the
    ``self.mu`` projection is inferred from ``forward``'s use of ``mu`` —
    confirm against the original file.
    """

    def __init__(self, hidden_size, num_outputs, conditional_sigma=False, **kwargs):
        """
        :param hidden_size: width of the incoming hidden layer.
        :param num_outputs: dimensionality of the action space.
        :param conditional_sigma: if True, predict log-sigma from the input
            via a linear layer; otherwise use a single learned parameter
            shared across all states.
        """
        super().__init__()
        self.conditional_sigma = conditional_sigma
        # Mean head. Small-gain init keeps initial actions near zero,
        # which stabilizes early policy-gradient updates.
        self.mu = nn.Linear(hidden_size, num_outputs)
        nn.init.xavier_uniform_(self.mu.weight, gain=0.01)
        if conditional_sigma:
            # State-dependent log standard deviation.
            self.log_sigma = nn.Linear(hidden_size, num_outputs)
            nn.init.xavier_uniform_(self.log_sigma.weight, gain=0.01)
        else:
            # State-independent log standard deviation, starts at
            # log(sigma) = 0, i.e. sigma = 1.
            self.log_sigma = nn.Parameter(
                torch.zeros(1, num_outputs, requires_grad=True)
            )

    def forward(self, inputs):
        """Return a one-element list with the Normal over actions.

        :param inputs: hidden activations, shape (batch, hidden_size) —
            presumed from the Linear layers; confirm against callers.
        """
        mu = self.mu(inputs)
        if self.conditional_sigma:
            log_sigma = self.log_sigma(inputs)
        else:
            log_sigma = self.log_sigma
        # exp() maps the unconstrained log-sigma to a strictly positive scale.
        return [distributions.normal.Normal(loc=mu, scale=torch.exp(log_sigma))]
|
|
|
|
|
|
|
|
|
|
|
class MultiCategoricalDistribution(nn.Module): |
|
|
|