|
|
|
|
|
|
self.encoding_size = hidden_size |
|
|
|
self.continuous_act_size = continuous_act_size |
|
|
|
self.discrete_act_size = discrete_act_size |
|
|
|
self.continuous_distribution = None #: List[GaussianDistribution] = [] |
|
|
|
self.discrete_distribution = None #: List[MultiCategoricalDistribution] = [] |
|
|
|
self.continuous_distributions : List[GaussianDistribution] = [] |
|
|
|
self.discrete_distributions : List[MultiCategoricalDistribution] = [] |
|
|
|
self.continuous_distribution = GaussianDistribution( |
|
|
|
self.continuous_distributions.append(GaussianDistribution( |
|
|
|
) |
|
|
|
self.discrete_distribution = MultiCategoricalDistribution(self.encoding_size, discrete_act_size) |
|
|
|
self.discrete_distributions.append(MultiCategoricalDistribution(self.encoding_size, discrete_act_size)) |
|
|
|
|
|
|
|
|
|
|
|
def evaluate(self, inputs: torch.Tensor, masks: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
|
|
|
|
|
|
|
continuous_action_list = [continuous_actions[..., i] for i in range(continuous_actions.shape[-1])] |
|
|
|
continuous_log_probs, continuous_entropies, _ = ModelUtils.get_probs_and_entropy(continuous_action_list, continuous_dists) |
|
|
|
|
|
|
|
|
|
|
return torch.cat([dist.exported_model_output() for dist in dists], dim=1) |
|
|
|
|
|
|
|
def _get_dists(self, inputs: torch.Tensor, masks: torch.Tensor) -> Tuple[List[DistInstance], List[DiscreteDistInstance]]:
    """
    Build the continuous and discrete distribution instances for ``inputs``.

    This replaces two back-to-back implementations, the second of which was
    unreachable behind the first ``return``. The reachable behaviour (calling
    the single ``continuous_distribution`` / ``discrete_distribution`` heads)
    is kept.

    :param inputs: Tensor handed to both distribution heads.
    :param masks: Tensor handed, together with ``inputs``, to the discrete head.
    :return: Tuple ``(continuous_dist_instances, discrete_dist_instances)``.
        A head that is still ``None`` contributes an empty list.
    """
    # Both heads are initialised to None and only later assigned a
    # distribution module (presumably only when the matching action space
    # exists -- TODO confirm against the constructor). Calling an unset head
    # would raise TypeError, so report "no distributions" instead.
    continuous_head = self.continuous_distribution
    discrete_head = self.discrete_distribution

    continuous_dist_instances = [] if continuous_head is None else continuous_head(inputs)
    discrete_dist_instances = [] if discrete_head is None else discrete_head(inputs, masks)

    return continuous_dist_instances, discrete_dist_instances
|
|
|
|
|
|
|
def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
|
continuous_dists, discrete_dists = self._get_dists(inputs, masks) |
|
|
|