from typing import List, Tuple, NamedTuple, Optional

from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch.distributions import (
    DistInstance,
    DiscreteDistInstance,
    GaussianDistribution,
    GaussianHyperNetwork,
    MultiCategoricalDistribution,
)
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents_envs.base_env import ActionSpec

EPSILON = 1e-7  # Small value to avoid divide by zero


class DistInstances(NamedTuple):
    """
    A NamedTuple with fields corresponding to the DistInstance objects
    output by continuous and discrete distributions, respectively. Discrete distributions
    output a list of DistInstance objects whereas continuous distributions output a single
    DistInstance object.
    """

    continuous: Optional[DistInstance]
    discrete: Optional[List[DiscreteDistInstance]]


class ActionModel(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        action_spec: ActionSpec,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
        hypernet_num_layers: int = 1,
        hypernet_layer_size: int = 256,
        num_goals: int = 2,
    ):
        """
        A torch module that represents the action space of a policy. The ActionModel may contain
        a continuous distribution, a discrete distribution or both where construction depends on
        the action_spec. The continuous part is a GaussianHyperNetwork that additionally receives
        a goal tensor; the discrete part is a MultiCategoricalDistribution that does not.
        The ActionModel uses the encoded input of the network body to parameterize these
        distributions. The forward method of this module outputs the action, log probs,
        and entropies given the encoding from the network body.
        :params hidden_size: Size of the input to the ActionModel.
        :params action_spec: The ActionSpec defining the action space dimensions and distributions.
        :params conditional_sigma: Whether or not the std of a Gaussian is conditioned on state.
        :params tanh_squash: Whether to squash the output of a Gaussian with the tanh function.
        :params hypernet_num_layers: Passed to GaussianHyperNetwork as ``num_layers``
            (defaults to the previously hard-coded value 1).
        :params hypernet_layer_size: Passed to GaussianHyperNetwork as ``layer_size``
            (defaults to the previously hard-coded value 256).
        :params num_goals: Passed to GaussianHyperNetwork as ``num_goals`` (defaults to the
            previously hard-coded value 2). NOTE(review): presumably the number of distinct
            goals the hypernetwork conditions on; confirm against GaussianHyperNetwork.
        """
        super().__init__()
        self.encoding_size = hidden_size
        self.action_spec = action_spec
        self._continuous_distribution = None
        self._discrete_distribution = None

        if self.action_spec.continuous_size > 0:
            self._continuous_distribution = GaussianHyperNetwork(
                num_layers=hypernet_num_layers,
                layer_size=hypernet_layer_size,
                hidden_size=self.encoding_size,
                num_outputs=self.action_spec.continuous_size,
                conditional_sigma=conditional_sigma,
                tanh_squash=tanh_squash,
                num_goals=num_goals,
            )

        if self.action_spec.discrete_size > 0:
            self._discrete_distribution = MultiCategoricalDistribution(
                self.encoding_size, self.action_spec.discrete_branches
            )

        # During training, clipping is done in TorchPolicy, but we need to clip before ONNX
        # export as well.
        self._clip_action_on_export = not tanh_squash

    def _sample_action(self, dists: DistInstances) -> AgentAction:
        """
        Samples actions from a DistInstances tuple
        :params dists: The DistInstances tuple
        :return: An AgentAction corresponding to the actions sampled from the DistInstances
        """
        continuous_action: Optional[torch.Tensor] = None
        discrete_action: Optional[List[torch.Tensor]] = None
        # This checks None because mypy complains otherwise
        if dists.continuous is not None:
            continuous_action = dists.continuous.sample()
        if dists.discrete is not None:
            # One sampled tensor per discrete branch, in branch order.
            discrete_action = [discrete_dist.sample() for discrete_dist in dists.discrete]
        return AgentAction(continuous_action, discrete_action)

    def _get_dists(
        self, inputs: torch.Tensor, masks: torch.Tensor, goal: torch.Tensor
    ) -> DistInstances:
        """
        Creates a DistInstances tuple using the continuous and discrete distributions
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :params goal: Goal tensor forwarded to the continuous GaussianHyperNetwork; it is not
            passed to the discrete distribution.
        :return: A DistInstances tuple
        """
        continuous_dist: Optional[DistInstance] = None
        discrete_dist: Optional[List[DiscreteDistInstance]] = None
        # This checks None because mypy complains otherwise
        if self._continuous_distribution is not None:
            continuous_dist = self._continuous_distribution(inputs, goal)
        if self._discrete_distribution is not None:
            discrete_dist = self._discrete_distribution(inputs, masks)
        return DistInstances(continuous_dist, discrete_dist)

    def _get_probs_and_entropy(
        self, actions: AgentAction, dists: DistInstances
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
        """
        Computes the log probabilities of the actions given distributions and entropies of
        the given distributions.
        :params actions: The AgentAction
        :params dists: The DistInstances tuple
        :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies,
            concatenated along dim 1 (continuous first, then each discrete branch).
        """
        entropies_list: List[torch.Tensor] = []
        continuous_log_prob: Optional[torch.Tensor] = None
        discrete_log_probs: Optional[List[torch.Tensor]] = None
        all_discrete_log_probs: Optional[List[torch.Tensor]] = None
        # This checks None because mypy complains otherwise
        if dists.continuous is not None:
            continuous_log_prob = dists.continuous.log_prob(actions.continuous_tensor)
            entropies_list.append(dists.continuous.entropy())
        if dists.discrete is not None:
            discrete_log_probs = []
            all_discrete_log_probs = []
            for discrete_action, discrete_dist in zip(
                actions.discrete_list, dists.discrete  # type: ignore
            ):
                discrete_log_prob = discrete_dist.log_prob(discrete_action)
                entropies_list.append(discrete_dist.entropy())
                discrete_log_probs.append(discrete_log_prob)
                all_discrete_log_probs.append(discrete_dist.all_log_prob())
        action_log_probs = ActionLogProbs(
            continuous_log_prob, discrete_log_probs, all_discrete_log_probs
        )
        entropies = torch.cat(entropies_list, dim=1)
        return action_log_probs, entropies

    def evaluate(
        self,
        inputs: torch.Tensor,
        masks: torch.Tensor,
        actions: AgentAction,
        goal: torch.Tensor,
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
        """
        Given actions and encoding from the network body, gets the distributions and
        computes the log probabilities and entropies.
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :params actions: The AgentAction
        :params goal: Goal tensor forwarded to the continuous distribution
        :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies
            summed over dim 1.
        """
        dists = self._get_dists(inputs, masks, goal)
        log_probs, entropies = self._get_probs_and_entropy(actions, dists)
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)
        return log_probs, entropy_sum

    def get_action_out(
        self, inputs: torch.Tensor, masks: torch.Tensor, goal: torch.Tensor
    ) -> Tuple[
        Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]
    ]:
        """
        Gets the tensors corresponding to the output of the policy network to be used for
        inference. Called by the Actor's forward call.
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :params goal: Goal tensor forwarded to the continuous distribution
        :return: A tuple (continuous_out, discrete_out, action_out_deprecated). Each entry is
            None when not applicable; action_out_deprecated is always None for hybrid
            (continuous + discrete) action spaces.
        """
        dists = self._get_dists(inputs, masks, goal)
        continuous_out: Optional[torch.Tensor] = None
        discrete_out: Optional[torch.Tensor] = None
        action_out_deprecated: Optional[torch.Tensor] = None

        if self.action_spec.continuous_size > 0 and dists.continuous is not None:
            # exported_model_output() is called twice, so the deprecated output is a
            # separate tensor from continuous_out. NOTE(review): presumably this keeps the
            # two exported outputs as distinct graph nodes; if exported_model_output()
            # samples, the two values may also differ. Confirm before merging the calls.
            continuous_out = dists.continuous.exported_model_output()
            action_out_deprecated = dists.continuous.exported_model_output()
            if self._clip_action_on_export:
                # Clip to [-3, 3], then rescale into [-1, 1].
                continuous_out = torch.clamp(continuous_out, -3, 3) / 3
                action_out_deprecated = torch.clamp(action_out_deprecated, -3, 3) / 3

        if self.action_spec.discrete_size > 0 and dists.discrete is not None:
            discrete_out_list = [
                discrete_dist.exported_model_output()
                for discrete_dist in dists.discrete
            ]
            # Concatenated separately for the same reason as the continuous outputs above.
            discrete_out = torch.cat(discrete_out_list, dim=1)
            action_out_deprecated = torch.cat(discrete_out_list, dim=1)

        # deprecated action field does not support hybrid action
        if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
            action_out_deprecated = None
        return continuous_out, discrete_out, action_out_deprecated

    def forward(
        self, inputs: torch.Tensor, masks: torch.Tensor, goal: torch.Tensor
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor]:
        """
        The forward method of this module. Outputs the action, log probs,
        and entropies given the encoding from the network body.
        :params inputs: The encoding from the network body
        :params masks: Action masks for discrete actions
        :params goal: Goal tensor forwarded to the continuous distribution
        :return: Given the input, an AgentAction of the actions generated by the policy and the
            corresponding ActionLogProbs and entropies (summed over dim 1).
        """
        dists = self._get_dists(inputs, masks, goal)
        actions = self._sample_action(dists)
        log_probs, entropies = self._get_probs_and_entropy(actions, dists)
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)
        return (actions, log_probs, entropy_sum)