import abc
from typing import List, Tuple

from mlagents.torch_utils import torch, nn
import numpy as np
import math

from mlagents.trainers.torch.layers import linear_layer, Initialization
from mlagents.trainers.torch.distributions import (
    DistInstance,
    DiscreteDistInstance,
    GaussianDistribution,
    MultiCategoricalDistribution,
)
from mlagents.trainers.torch.utils import ModelUtils

EPSILON = 1e-7  # Small value to avoid divide by zero


class ActionModel(nn.Module, abc.ABC):
    # @abc.abstractmethod
    # def entropy(self, action_list: np.ndarray) -> torch.Tensor:
    #     pass
    # @abc.abstractmethod
    # def log_probs(self, action_list: np.ndarray) -> torch.Tensor:
    #     pass

    def _sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
        """Draw one sample from each distribution instance in `dists`."""
        actions = []
        for action_dist in dists:
            action = action_dist.sample()
            actions.append(action)
        return actions

    @abc.abstractmethod
    def forward(self, inputs: torch.Tensor, masks: torch.Tensor):
        pass


class HybridActionModel(ActionModel):
    def __init__(
        self,
        hidden_size: int,
        continuous_act_size: int,
        discrete_act_size: List[int],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        super().__init__()
        self.encoding_size = hidden_size
        self.continuous_act_size = continuous_act_size
        self.discrete_act_size = discrete_act_size
        self.continuous_distribution = None  #: List[GaussianDistribution] = []
        self.discrete_distribution = None  #: List[MultiCategoricalDistribution] = []
        if continuous_act_size > 0:
            self.continuous_distribution = GaussianDistribution(
                self.encoding_size,
                continuous_act_size,
                conditional_sigma=conditional_sigma,
                tanh_squash=tanh_squash,
            )
        if len(discrete_act_size) > 0:
            self.discrete_distribution = MultiCategoricalDistribution(
                self.encoding_size, discrete_act_size
            )

    def evaluate(
        self, inputs: torch.Tensor, masks: torch.Tensor, actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the log probabilities and entropies of `actions` under the
        current distributions. `actions` is expected to hold the continuous
        action columns first, followed by one column per discrete branch.
        """
        continuous_dists, discrete_dists = self._get_dists(inputs, masks)
        continuous_actions, discrete_actions = torch.split(
            actions, [self.continuous_act_size, len(self.discrete_act_size)], dim=1
        )

        continuous_action_list = [
            continuous_actions[..., i] for i in range(continuous_actions.shape[-1])
        ]
        continuous_log_probs, continuous_entropies, _ = ModelUtils.get_probs_and_entropy(
            continuous_action_list, continuous_dists
        )

        discrete_action_list = [
            discrete_actions[:, i] for i in range(len(self.discrete_act_size))
        ]
        discrete_log_probs, discrete_entropies, _ = ModelUtils.get_probs_and_entropy(
            discrete_action_list, discrete_dists
        )

        log_probs = torch.cat([continuous_log_probs, discrete_log_probs], dim=1)
        # The discrete entropies are averaged over dim 0 before concatenation.
        entropies = torch.cat(
            [continuous_entropies, torch.mean(discrete_entropies, dim=0).unsqueeze(0)],
            dim=1,
        )
        return log_probs, entropies

    def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        continuous_dists, discrete_dists = self._get_dists(inputs, masks)
        dists = continuous_dists + discrete_dists
        return torch.cat([dist.exported_model_output() for dist in dists], dim=1)

    def _get_dists(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> Tuple[List[DistInstance], List[DiscreteDistInstance]]:
        # continuous_distributions: List[DistInstance] = []
        # discrete_distributions: List[DiscreteDistInstance] = []
        continuous_dist_instances = self.continuous_distribution(
            inputs
        )  # for continuous_dist in self.continuous_distributions]
        discrete_dist_instances = self.discrete_distribution(
            inputs, masks
        )  # for discrete_dist in self.discrete_distributions]
        # for continuous_dist in self.continuous_distributions:
        #     continuous_distributions += continuous_dist(inputs)
        # for discrete_dist in self.discrete_distributions:
        #     discrete_distributions += discrete_dist(inputs, masks)
        return continuous_dist_instances, discrete_dist_instances

    def forward(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample a hybrid action and return (action, log_probs, entropies)."""
        continuous_dists, discrete_dists = self._get_dists(inputs, masks)

        continuous_action_list = self._sample_action(continuous_dists)
        # Unpack as (log_probs, entropies, all_probs), the same order that
        # evaluate() uses; the original unpacking here had the first two swapped.
        continuous_log_probs, continuous_entropies, continuous_all_probs = ModelUtils.get_probs_and_entropy(
            continuous_action_list, continuous_dists
        )
        continuous_actions = torch.stack(continuous_action_list, dim=-1)
        continuous_actions = continuous_actions[:, :, 0]

        discrete_action_list = self._sample_action(discrete_dists)
        discrete_log_probs, discrete_entropies, discrete_all_probs = ModelUtils.get_probs_and_entropy(
            discrete_action_list, discrete_dists
        )
        discrete_actions = torch.stack(discrete_action_list, dim=-1)
        discrete_actions = discrete_actions[:, 0, :]

        # Discrete actions are cast to float so both parts share one tensor.
        action = torch.cat(
            [continuous_actions, discrete_actions.type(torch.float)], dim=1
        )
        log_probs = torch.cat([continuous_log_probs, discrete_log_probs], dim=1)
        entropies = torch.cat([continuous_entropies, discrete_entropies], dim=1)
        return (action, log_probs, entropies)