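"""
Torch probability distributions for ML-Agents.

Defines the abstract DistInstance / DiscreteDistInstance interfaces, concrete
Gaussian, tanh-squashed Gaussian, and categorical distribution instances, and
the nn.Module wrappers (GaussianDistribution, GaussianHyperNetwork,
MultiCategoricalDistribution) that map network hidden states to those instances.
"""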
import abc
from typing import List

from mlagents.torch_utils import torch, nn
import numpy as np
import math
from mlagents.trainers.torch.layers import (
    linear_layer,
    Initialization,
    LinearEncoder,
    Swish,
)
from mlagents.trainers.torch.utils import ModelUtils

EPSILON = 1e-7  # Small value to avoid divide by zero


class DistInstance(nn.Module, abc.ABC):
    @abc.abstractmethod
    def sample(self) -> torch.Tensor:
        """
        Return a sample from this distribution.
        """
        pass

    @abc.abstractmethod
    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Returns the log probabilities of a particular value.
        :param value: A value sampled from the distribution.
        :returns: Log probabilities of the given value.
        """
        pass

    @abc.abstractmethod
    def entropy(self) -> torch.Tensor:
        """
        Returns the entropy of this distribution.
        """
        pass

    @abc.abstractmethod
    def exported_model_output(self) -> torch.Tensor:
        """
        Returns the tensor to be exported to ONNX for the distribution.
        """
        pass


class DiscreteDistInstance(DistInstance):
    @abc.abstractmethod
    def all_log_prob(self) -> torch.Tensor:
        """
        Returns the log probabilities of all actions represented by this distribution.
        """
        pass


class GaussianDistInstance(DistInstance):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def sample(self):
        sample = self.mean + torch.randn_like(self.mean) * self.std
        return sample

    def log_prob(self, value):
        var = self.std ** 2
        log_scale = torch.log(self.std + EPSILON)
        return (
            -((value - self.mean) ** 2) / (2 * var + EPSILON)
            - log_scale
            - math.log(math.sqrt(2 * math.pi))
        )

    def pdf(self, value):
        log_prob = self.log_prob(value)
        return torch.exp(log_prob)

    def entropy(self):
        return torch.mean(
            0.5 * torch.log(2 * math.pi * math.e * self.std + EPSILON),
            dim=1,
            keepdim=True,
        )  # Use equivalent behavior to TF

    def exported_model_output(self):
        return self.sample()


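# Usage sketch (illustrative only; a batch of 4 agents with 2 continuous actions
# is assumed):
#
#   dist = GaussianDistInstance(torch.zeros(4, 2), torch.ones(4, 2))
#   action = dist.sample()         # (4, 2) sample: mean + noise * std
#   log_p = dist.log_prob(action)  # (4, 2) elementwise Gaussian log-density
#   ent = dist.entropy()           # (4, 1) mean over action dims, keepdim=True

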
class TanhGaussianDistInstance(GaussianDistInstance):
    def __init__(self, mean, std):
        super().__init__(mean, std)
        self.transform = torch.distributions.transforms.TanhTransform(cache_size=1)

    def sample(self):
        unsquashed_sample = super().sample()
        squashed = self.transform(unsquashed_sample)
        return squashed

    def _inverse_tanh(self, value):
        capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
        return 0.5 * torch.log((1 + capped_value) / (1 - capped_value) + EPSILON)

    def log_prob(self, value):
        unsquashed = self.transform.inv(value)
        return super().log_prob(unsquashed) - self.transform.log_abs_det_jacobian(
            unsquashed, value
        )


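# Note on the tanh squash (illustrative): for y = tanh(x) with x ~ N(mean, std),
# the change-of-variables formula gives
#     log p(y) = log p(x) - log|d tanh(x)/dx| = log p(x) - log(1 - tanh(x)^2),
# which is what log_prob() above computes via TanhTransform.log_abs_det_jacobian.

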
class CategoricalDistInstance(DiscreteDistInstance):
    def __init__(self, logits):
        super().__init__()
        self.logits = logits
        self.probs = torch.softmax(self.logits, dim=-1)

    def sample(self):
        return torch.multinomial(self.probs, 1)

    def pdf(self, value):
        # This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]),
        # but torch.diag is not supported by ONNX export.
        idx = torch.arange(start=0, end=len(value)).unsqueeze(-1)
        return torch.gather(
            self.probs.permute(1, 0)[value.flatten().long()], -1, idx
        ).squeeze(-1)

    def log_prob(self, value):
        return torch.log(self.pdf(value) + EPSILON)

    def all_log_prob(self):
        return torch.log(self.probs + EPSILON)

    def entropy(self):
        return -torch.sum(
            self.probs * torch.log(self.probs + EPSILON), dim=-1
        ).unsqueeze(-1)

    def exported_model_output(self):
        return self.all_log_prob()


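# Usage sketch (illustrative only; a batch of 3 agents choosing among 4 discrete
# actions is assumed):
#
#   dist = CategoricalDistInstance(torch.zeros(3, 4))  # uniform logits
#   action = dist.sample()         # (3, 1) indices drawn from the softmax probs
#   log_p = dist.log_prob(action)  # (3,) log-probability of each chosen index
#   ent = dist.entropy()           # (3, 1) entropy of each row's distribution

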
class GaussianDistribution(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_outputs: int,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        super().__init__()
        self.conditional_sigma = conditional_sigma
        self.mu = linear_layer(
            hidden_size,
            num_outputs,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.2,
            bias_init=Initialization.Zero,
        )
        self.tanh_squash = tanh_squash
        if conditional_sigma:
            self.log_sigma = linear_layer(
                hidden_size,
                num_outputs,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.2,
                bias_init=Initialization.Zero,
            )
        else:
            self.log_sigma = nn.Parameter(
                torch.zeros(1, num_outputs, requires_grad=True)
            )

    def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
        mu = self.mu(inputs)
        if self.conditional_sigma:
            log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
        else:
            # Expand so that entropy matches batch size. Note that we're using
            # torch.cat here instead of torch.expand() because it is not supported in the
            # verified version of Barracuda (1.0.2).
            log_sigma = torch.cat([self.log_sigma] * inputs.shape[0], axis=0)
        if self.tanh_squash:
            return TanhGaussianDistInstance(mu, torch.exp(log_sigma))
        else:
            return GaussianDistInstance(mu, torch.exp(log_sigma))


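# Usage sketch (illustrative only; the hidden size and action count are assumed):
#
#   gauss = GaussianDistribution(hidden_size=128, num_outputs=2)
#   hidden = torch.zeros(4, 128)   # encoder output for a batch of 4
#   dist = gauss(hidden)           # GaussianDistInstance (Tanh variant if tanh_squash=True)
#   action = dist.sample()         # (4, 2) continuous action sample

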
class GaussianHyperNetwork(nn.Module):
    def __init__(
        self,
        num_layers,
        layer_size,
        hidden_size,
        num_outputs,
        conditional_sigma,
        tanh_squash,
        num_goals,
    ):
        super().__init__()
        self._num_goals = num_goals
        self.hidden_size = hidden_size
        self.tanh_squash = tanh_squash
        self.conditional_sigma = conditional_sigma
        self.num_outputs = num_outputs
        layers = []
        layers.append(
            linear_layer(
                num_goals,
                layer_size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
        )
        layers.append(Swish())
        for _ in range(num_layers - 1):
            layers.append(
                linear_layer(
                    layer_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=0.1,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
        if conditional_sigma:
            flat_output = linear_layer(
                layer_size,
                2 * (hidden_size * num_outputs + num_outputs),
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
            self._log_sigma_w = None
        else:
            flat_output = linear_layer(
                layer_size,
                hidden_size * num_outputs + num_outputs,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
            self._log_sigma_w = linear_layer(
                num_goals,
                num_outputs,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
        self.hypernet = torch.nn.Sequential(*layers, flat_output)

    def forward(self, inputs: torch.Tensor, goal: torch.Tensor):
        goal_onehot = torch.nn.functional.one_hot(
            goal[0].long(), self._num_goals
        ).float()

        # Shape of flat_output_weights (H = hidden_size, O = num_outputs):
        #   conditional_sigma:     (b, 2 * (H * O + O))
        #   not conditional_sigma: (b, H * O + O)
        flat_output_weights = self.hypernet(goal_onehot)
        b = inputs.size(0)
        inputs = inputs.unsqueeze(dim=1)
        if self.conditional_sigma:
            mu_w_log_sigma_w, mu_b, log_sigma_b = torch.split(
                flat_output_weights,
                [
                    2 * self.hidden_size * self.num_outputs,
                    self.num_outputs,
                    self.num_outputs,
                ],
                dim=-1,
            )
            mu_w_log_sigma_w = torch.reshape(
                mu_w_log_sigma_w, (b, 2 * self.hidden_size, self.num_outputs)
            )

            mu_w, log_sigma_w = torch.split(mu_w_log_sigma_w, self.hidden_size, dim=1)
            log_sigma = torch.bmm(inputs, log_sigma_w)
            log_sigma = log_sigma + log_sigma_b
            log_sigma = log_sigma.squeeze()
            log_sigma = torch.clamp(log_sigma, min=-20, max=2)
        else:
            mu_w, mu_b = torch.split(
                flat_output_weights, self.hidden_size * self.num_outputs, dim=-1
            )
            mu_w = torch.reshape(mu_w, (b, self.hidden_size, self.num_outputs))
            log_sigma = self._log_sigma_w(goal_onehot)
            log_sigma = torch.squeeze(log_sigma)

        mu = torch.bmm(inputs, mu_w)
        mu = mu + mu_b
        mu = mu.squeeze()
        if self.tanh_squash:
            return TanhGaussianDistInstance(mu, torch.exp(log_sigma))
        else:
            return GaussianDistInstance(mu, torch.exp(log_sigma))


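# Data-flow sketch for the class above (illustrative): the hypernetwork maps a
# one-hot goal id to a flat weight vector, which is reshaped into per-sample
# (hidden_size, num_outputs) weight matrices (`mu_w`, and `log_sigma_w` when
# conditional_sigma is True) plus biases; the hidden state `inputs` of shape
# (b, hidden_size) is then pushed through these weights with torch.bmm to
# produce goal-conditioned Gaussian parameters for each sample.

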
class MultiCategoricalDistribution(nn.Module):
    def __init__(self, hidden_size: int, act_sizes: List[int]):
        super().__init__()
        self.act_sizes = act_sizes
        self.branches = self._create_policy_branches(hidden_size)

    def _create_policy_branches(self, hidden_size: int) -> nn.ModuleList:
        branches = []
        for size in self.act_sizes:
            branch_output_layer = linear_layer(
                hidden_size,
                size,
                kernel_init=Initialization.KaimingHeNormal,
                kernel_gain=0.1,
                bias_init=Initialization.Zero,
            )
            branches.append(branch_output_layer)
        return nn.ModuleList(branches)

    def _mask_branch(
        self, logits: torch.Tensor, allow_mask: torch.Tensor
    ) -> torch.Tensor:
        # Zero out masked logits, then subtract a large value. Technique mentioned here:
        # https://arxiv.org/abs/2006.14171. Our implementation is ONNX and Barracuda-friendly.
        block_mask = -1.0 * allow_mask + 1.0
        # We do -1 * tensor + constant instead of constant - tensor because it seems
        # Barracuda might swap the inputs of a "Sub" operation
        logits = logits * allow_mask - 1e8 * block_mask
        return logits

    def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]:
        split_masks = []
        for idx, _ in enumerate(self.act_sizes):
            start = int(np.sum(self.act_sizes[:idx]))
            end = int(np.sum(self.act_sizes[: idx + 1]))
            split_masks.append(masks[:, start:end])
        return split_masks

    def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
        # Todo - Support multiple branches in mask code
        branch_distributions = []
        masks = self._split_masks(masks)
        for idx, branch in enumerate(self.branches):
            logits = branch(inputs)
            norm_logits = self._mask_branch(logits, masks[idx])
            distribution = CategoricalDistInstance(norm_logits)
            branch_distributions.append(distribution)
        return branch_distributions
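

# Usage sketch (illustrative only; two discrete branches of sizes 3 and 2 are assumed):
#
#   multi = MultiCategoricalDistribution(hidden_size=128, act_sizes=[3, 2])
#   hidden = torch.zeros(4, 128)              # encoder output for a batch of 4
#   masks = torch.ones(4, 5)                  # concatenated per-branch allow-masks
#   dists = multi(hidden, masks)              # list of CategoricalDistInstance, one per branch
#   actions = [d.sample() for d in dists]     # one (4, 1) index tensor per branch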