from mlagents.torch_utils import torch, nn
import numpy as np
import math
from mlagents.trainers.torch.layers import (
    linear_layer,
    Initialization,
    LinearEncoder,
    Swish,
    HyperNetwork,
)
from mlagents.trainers.torch.utils import ModelUtils

EPSILON = 1e-7  # Small value to avoid divide by zero
|
|
class GaussianDistribution(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_outputs: int,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        super().__init__()
        self.conditional_sigma = conditional_sigma
        self.tanh_squash = tanh_squash
        self.mu = linear_layer(
            hidden_size,
            num_outputs,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.2,
            bias_init=Initialization.Zero,
        )
        if conditional_sigma:
            # Sigma is predicted from the hidden state.
            self.log_sigma = linear_layer(
                hidden_size,
                num_outputs,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.2,
                bias_init=Initialization.Zero,
            )
        else:
            # Sigma is a single learned parameter, shared across the batch.
            self.log_sigma = nn.Parameter(
                torch.zeros(1, num_outputs, requires_grad=True)
            )

    def forward(self, inputs):
        mu = self.mu(inputs)
        if self.conditional_sigma:
            log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
        else:
            # Expand so that entropy matches batch size. Note that we're using
            # torch.cat here instead of torch.expand() because it is not supported in the
            # verified version of Barracuda (1.0.2).
            log_sigma = torch.cat([self.log_sigma] * inputs.shape[0], axis=0)
        if self.tanh_squash:
            return TanhGaussianDistInstance(mu, torch.exp(log_sigma))
        else:
            return GaussianDistInstance(mu, torch.exp(log_sigma))


class GoalConditionedGaussianDistribution(nn.Module):
    # Goal-conditioned variant (class name assumed): the weights of the final
    # mu layer -- and of the log_sigma layer when conditional_sigma is set --
    # are generated by a small hypernetwork from a one-hot goal encoding.
    def __init__(
        self,
        hidden_size: int,
        num_outputs: int,
        conditional_sigma: bool,
        tanh_squash: bool,
        num_goals: int,
        num_layers: int,
        layer_size: int,
    ):
        super().__init__()
        self.conditional_sigma = conditional_sigma
        self.tanh_squash = tanh_squash
        self.hidden_size = hidden_size
        self.num_outputs = num_outputs
        self._num_goals = num_goals
        # Hypernetwork trunk: one-hot goal -> layer_size features. A HyperNetwork
        # layer, e.g. HyperNetwork(hidden_size, num_outputs, goal_size, num_layers,
        # layer_size) with goal_size the goal input width, provides a similar
        # goal-conditioned weight generator; the generator is built explicitly
        # here so that forward() can split its flat output itself.
        layers = []
        layers.append(
            linear_layer(
                num_goals,
                layer_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
        )
        layers.append(Swish())
        for _ in range(num_layers - 1):
            layers.append(
                linear_layer(
                    layer_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=0.1,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
        if conditional_sigma:
            # Head emits weights and biases for both mu and log_sigma.
            flat_output = linear_layer(
                layer_size,
                2 * (hidden_size * num_outputs + num_outputs),
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
            self._log_sigma_w = None
        else:
            # Head emits weights and biases for mu only; log_sigma comes from a
            # separate goal-conditioned linear layer.
            flat_output = linear_layer(
                layer_size,
                hidden_size * num_outputs + num_outputs,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
            self._log_sigma_w = linear_layer(
                num_goals,
                num_outputs,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
        self.hypernet = torch.nn.Sequential(*layers, flat_output)
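
    # Worked size example (illustrative; these numbers are not from the source):
    # with hidden_size H = 3 and num_outputs O = 2, the head above emits
    #   conditional_sigma:     2 * (3 * 2 + 2) = 16 values, split in forward()
    #                          into mu_w (3x2), log_sigma_w (3x2), mu_b (2),
    #                          log_sigma_b (2)
    #   not conditional_sigma: 3 * 2 + 2 = 8 values, split into mu_w (3x2),
    #                          mu_b (2)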
|
|
|
    def forward(self, inputs, goal):
        # Goal index -> one-hot vector that drives the hypernetwork.
        goal_onehot = torch.nn.functional.one_hot(
            goal[0].long(), self._num_goals
        ).float()
        # Flat weight vector from the hypernetwork, with H = hidden_size and
        # O = num_outputs:
        #   conditional_sigma:     (b, 2 * H * O + 2 * O)
        #   not conditional_sigma: (b, H * O + O)
        flat_output_weights = self.hypernet(goal_onehot)
        b = inputs.size(0)
        # Treat each hidden vector as a (1, H) row so it can be multiplied with
        # its per-sample (H, O) weight matrix via bmm.
        inputs = inputs.unsqueeze(dim=1)
        if self.conditional_sigma:
            mu_w_log_sigma_w, mu_b, log_sigma_b = torch.split(
                flat_output_weights,
                [
                    2 * self.hidden_size * self.num_outputs,
                    self.num_outputs,
                    self.num_outputs,
                ],
                dim=-1,
            )
            mu_w_log_sigma_w = torch.reshape(
                mu_w_log_sigma_w, (b, 2 * self.hidden_size, self.num_outputs)
            )
            mu_w, log_sigma_w = torch.split(mu_w_log_sigma_w, self.hidden_size, dim=1)
            log_sigma = torch.bmm(inputs, log_sigma_w).squeeze(1)
            log_sigma = log_sigma + log_sigma_b
            log_sigma = torch.clamp(log_sigma, min=-20, max=2)
        else:
            mu_w, mu_b = torch.split(
                flat_output_weights, self.hidden_size * self.num_outputs, dim=-1
            )
            mu_w = torch.reshape(mu_w, (b, self.hidden_size, self.num_outputs))
            log_sigma = self._log_sigma_w(goal_onehot)
        mu = torch.bmm(inputs, mu_w).squeeze(1)
        mu = mu + mu_b
        if self.tanh_squash:
            return TanhGaussianDistInstance(mu, torch.exp(log_sigma))
        else:
            return GaussianDistInstance(mu, torch.exp(log_sigma))
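

# Usage sketch (illustrative only; the class name and the exact wiring into the
# trainer are assumptions, and GaussianDistInstance / TanhGaussianDistInstance
# are assumed to be defined elsewhere in this module):
#
#   dist_module = GoalConditionedGaussianDistribution(
#       hidden_size=128, num_outputs=2, conditional_sigma=False,
#       tanh_squash=False, num_goals=4, num_layers=2, layer_size=64,
#   )
#   hidden = torch.randn(8, 128)        # encoder output for a batch of 8
#   goal = [torch.randint(0, 4, (8,))]  # one goal index per sample
#   dist = dist_module(hidden, goal)
#   action = dist.sample()
#   log_prob = dist.log_prob(action)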
|
|
|