from typing import List, Tuple, NamedTuple, Optional

import torch
from torch import nn

from mlagents.trainers.torch.distributions import (
    DistInstance,
    DiscreteDistInstance,
    GaussianDistribution,
    MultiCategoricalDistribution,
)
from mlagents.trainers.torch.utils import AgentAction, ActionLogProbs
from mlagents_envs.base_env import ActionSpec


class DistInstances(NamedTuple): |
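    """
    A NamedTuple holding the distribution instances for the continuous and the
    discrete action branches. Either field may be None when the corresponding
    branch is not part of the action space.
    """
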
    continuous: Optional[DistInstance]
    discrete: Optional[List[DiscreteDistInstance]]


class ActionModel(nn.Module): |
    def __init__(
        self,
        hidden_size: int,
        action_spec: ActionSpec,
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
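        """
        A torch module that represents the action space of a policy. Depending on
        the given ActionSpec it holds a continuous distribution, a discrete
        distribution, or both, each parameterized by the encoding produced by the
        network body.
        :param hidden_size: Size of the input encoding.
        :param action_spec: ActionSpec of the policy's action space.
        :param conditional_sigma: Whether the Gaussian's standard deviation is
            conditioned on the input encoding.
        :param tanh_squash: Whether continuous actions are squashed with tanh.
        """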
super().__init__() |
        self.encoding_size = hidden_size
        self.action_spec = action_spec
        self._continuous_distribution = None
        self._discrete_distribution = None
        if self.action_spec.continuous_size > 0:
            self._continuous_distribution = GaussianDistribution(
                self.encoding_size,
                self.action_spec.continuous_size,
                conditional_sigma=conditional_sigma,
                tanh_squash=tanh_squash,
            )
        if self.action_spec.discrete_size > 0:
            self._discrete_distribution = MultiCategoricalDistribution(
                self.encoding_size, self.action_spec.discrete_branches
            )

def _sample_action(self, dists: DistInstances) -> AgentAction: |
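        """
        Samples actions from the given DistInstances and packs them into an
        AgentAction. A branch that is absent from the action space stays None.
        """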
        continuous_action: Optional[torch.Tensor] = None
        discrete_action: Optional[List[torch.Tensor]] = None
        if self.action_spec.continuous_size > 0:
            continuous_action = dists.continuous.sample()
        if self.action_spec.discrete_size > 0:
            discrete_action = []
            for discrete_dist in dists.discrete:
                discrete_action.append(discrete_dist.sample())
        return AgentAction(continuous_action, discrete_action)

def _get_dists(self, inputs: torch.Tensor, masks: torch.Tensor) -> DistInstances: |
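        """
        Runs the continuous and/or discrete distribution modules on the encoded
        input and returns the resulting instances as a DistInstances tuple.
        """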
        continuous_dist: Optional[DistInstance] = None
        discrete_dist: Optional[List[DiscreteDistInstance]] = None
        if self.action_spec.continuous_size > 0:
            # Masks apply only to discrete branches; the Gaussian head takes
            # just the encoding.
            continuous_dist = self._continuous_distribution(inputs)
        if self.action_spec.discrete_size > 0:
            discrete_dist = self._discrete_distribution(inputs, masks)
        return DistInstances(continuous_dist, discrete_dist)

def _get_probs_and_entropy( |
        self, actions: AgentAction, dists: DistInstances
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
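        """
        Computes the log probabilities of the given actions under the given
        distribution instances, along with the entropies of those distributions,
        concatenated across action branches.
        """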
        entropies_list: List[torch.Tensor] = []
        continuous_log_prob: Optional[torch.Tensor] = None
        discrete_log_probs: Optional[List[torch.Tensor]] = None
        all_discrete_log_probs: Optional[List[torch.Tensor]] = None
        if self.action_spec.continuous_size > 0:
            continuous_log_prob = dists.continuous.log_prob(actions.continuous_tensor)
            entropies_list.append(dists.continuous.entropy())
        if self.action_spec.discrete_size > 0:
            discrete_log_probs = []
            all_discrete_log_probs = []
            for discrete_action, discrete_dist in zip(
                actions.discrete_list, dists.discrete
            ):
                discrete_log_prob = discrete_dist.log_prob(discrete_action)
                entropies_list.append(discrete_dist.entropy())
                discrete_log_probs.append(discrete_log_prob)
                all_discrete_log_probs.append(discrete_dist.all_log_prob())
        action_log_probs = ActionLogProbs(
            continuous_log_prob, discrete_log_probs, all_discrete_log_probs
        )
        entropies = torch.cat(entropies_list, dim=1)
        return action_log_probs, entropies

    def evaluate(
        self, inputs: torch.Tensor, masks: torch.Tensor, actions: AgentAction
    ) -> Tuple[ActionLogProbs, torch.Tensor]:
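        """
        Given an encoding and masks, returns the log probabilities of the provided
        actions under the current distributions and the summed entropy.
        """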
        dists = self._get_dists(inputs, masks)
        log_probs, entropies = self._get_probs_and_entropy(actions, dists)
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)
        return log_probs, entropy_sum

    def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
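        """
        Gets the tensors to be exported with the model for inference: the exported
        output of every action branch, concatenated along dim 1.
        """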
        dists = self._get_dists(inputs, masks)
        out_list: List[torch.Tensor] = []
        if self.action_spec.continuous_size > 0:
            out_list.append(dists.continuous.exported_model_output())
        if self.action_spec.discrete_size > 0:
            for discrete_dist in dists.discrete:
                out_list.append(discrete_dist.exported_model_output())
        return torch.cat(out_list, dim=1)

    def forward(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor]:
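        """
        Samples actions from the distributions parameterized by the encoded input
        and returns the actions, their log probabilities, and the summed entropy.
        """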
        dists = self._get_dists(inputs, masks)
        actions = self._sample_action(dists)
        log_probs, entropies = self._get_probs_and_entropy(actions, dists)
        # Use the sum of entropy across actions, not the mean
        entropy_sum = torch.sum(entropies, dim=1)
        return (actions, log_probs, entropy_sum)
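

# Minimal usage sketch (illustrative, not part of the trainer code): builds an
# ActionModel for a hybrid action space and runs one forward pass. It assumes
# ActionSpec is the NamedTuple from mlagents_envs.base_env with the fields
# continuous_size and discrete_branches, and that an all-ones mask marks every
# discrete action as available.
if __name__ == "__main__":
    spec = ActionSpec(continuous_size=2, discrete_branches=(3, 2))
    model = ActionModel(hidden_size=64, action_spec=spec)
    encoding = torch.randn(10, 64)  # a batch of 10 encoded observations
    masks = torch.ones(10, 5)  # one mask slot per discrete action (3 + 2)
    actions, log_probs, entropy_sum = model(encoding, masks)
    # entropy_sum holds one summed entropy per batch element: torch.Size([10])
    print(actions.continuous_tensor.shape, len(actions.discrete_list), entropy_sum.shape)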