"""
Maps a policy's hidden encoding to action distributions: a Gaussian head for the
continuous part of the action space and a multi-categorical head for the discrete
branches. Sampled actions and their log-probabilities are packaged as AgentAction
and ActionLogProbs.
"""
from typing import List, Tuple

import torch
from torch import nn

from mlagents.trainers.torch.distributions import (
    DistInstance,
    DiscreteDistInstance,
    GaussianDistribution,
    MultiCategoricalDistribution,
)
from mlagents.trainers.torch.utils import ModelUtils, AgentAction, ActionLogProbs
from mlagents_envs.base_env import ActionSpec


class ActionModel(nn.Module): |
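    """
    Torch module that owns a policy's action heads. Given the encoder output, it builds
    a GaussianDistribution for the continuous part of the ActionSpec and a
    MultiCategoricalDistribution for the discrete branches, and exposes forward() for
    sampling and evaluate() for scoring previously taken actions.
    """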
    def __init__(
        self,
        hidden_size: int,
        action_spec: ActionSpec,
        conditional_sigma: bool = False,  # forwarded to the Gaussian head (assumed to match GaussianDistribution)
        tanh_squash: bool = False,
    ):
        super().__init__()
        self.encoding_size = hidden_size
        self.action_spec = action_spec
        # Keep the distribution heads in a ModuleList so their parameters are registered.
        self._distributions = nn.ModuleList()

        if self.action_spec.continuous_size > 0:
            self._distributions.append(
                GaussianDistribution(
                    self.encoding_size,
                    self.action_spec.continuous_size,
                    conditional_sigma=conditional_sigma,
                    tanh_squash=tanh_squash,
                )
            )
        if self.action_spec.discrete_size > 0:
            self._distributions.append(
                MultiCategoricalDistribution(
                    self.encoding_size, self.action_spec.discrete_branches
                )
            )

    def _sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:
        """
        Samples actions from a list of distribution instances.
        """
        actions = []
        for action_dist in dists:
            action = action_dist.sample()
            actions.append(action)
        return actions

    def _get_dists(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> List[DistInstance]:
        """
        Creates one DistInstance per action head (and per discrete branch) from the
        hidden encoding, applying the action masks to the discrete branches.
        """
        distribution_instances: List[DistInstance] = []
        for distribution in self._distributions:
            # Each head returns a list of DistInstances; flatten them into one list.
            dist_instances = distribution(inputs, masks)
            for dist_instance in dist_instances:
                distribution_instances.append(dist_instance)
        return distribution_instances

    def evaluate(
        self, inputs: torch.Tensor, masks: torch.Tensor, actions: AgentAction
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
        """
        Computes the log-probabilities of the given actions under the current
        distributions, along with the summed entropy.
        """
        dists = self._get_dists(inputs, masks)
        action_list = actions.to_tensor_list()
        log_probs_list, entropies, _ = ModelUtils.get_probs_and_entropy(
            action_list, dists
        )
        log_probs = ActionLogProbs.create(log_probs_list, self.action_spec)
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)
        return log_probs, entropy_sum

    def forward(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor]:
        """
        Samples actions from the current distributions and returns them together with
        their log-probabilities and the summed entropy.
        """
        dists = self._get_dists(inputs, masks)
        action_list = self._sample_action(dists)
        log_probs_list, entropies, all_logs_list = ModelUtils.get_probs_and_entropy(
            action_list, dists
        )
        actions = AgentAction.create(action_list, self.action_spec)
        log_probs = ActionLogProbs.create(
            log_probs_list, self.action_spec, all_logs_list
        )
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)
        return (actions, log_probs, entropy_sum)
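

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the trainer code path. It assumes ActionSpec
    # can be constructed directly from (continuous_size, discrete_branches), as in the
    # mlagents_envs.base_env version this module targets; adjust to your installed release.
    spec = ActionSpec(continuous_size=2, discrete_branches=(3, 2))
    model = ActionModel(hidden_size=64, action_spec=spec)

    encoding = torch.randn(5, 64)  # batch of 5 hidden encodings from the policy body
    masks = torch.ones(5, sum(spec.discrete_branches))  # no discrete actions masked out

    actions, log_probs, entropy = model(encoding, masks)
    print(entropy.shape)  # summed per-sample entropy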