浏览代码

using lists for distributions

/develop/hybrid-actions/distlist
Andrew Cohen 4 年前
当前提交
d6544353
共有 1 个文件被更改,包括 19 次插入14 次删除
  1. 33
      ml-agents/mlagents/trainers/torch/action_models.py

33
ml-agents/mlagents/trainers/torch/action_models.py


self.encoding_size = hidden_size
self.continuous_act_size = continuous_act_size
self.discrete_act_size = discrete_act_size
self.continuous_distribution = None #: List[GaussianDistribution] = []
self.discrete_distribution = None #: List[MultiCategoricalDistribution] = []
self.continuous_distributions : List[GaussianDistribution] = []
self.discrete_distributions : List[MultiCategoricalDistribution] = []
self.continuous_distribution = GaussianDistribution(
self.continuous_distributions.append(GaussianDistribution(
)
self.discrete_distribution = MultiCategoricalDistribution(self.encoding_size, discrete_act_size)
self.discrete_distributions.append(MultiCategoricalDistribution(self.encoding_size, discrete_act_size))
def evaluate(self, inputs: torch.Tensor, masks: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

continuous_action_list = [continuous_actions[..., i] for i in range(continuous_actions.shape[-1])]
continuous_log_probs, continuous_entropies, _ = ModelUtils.get_probs_and_entropy(continuous_action_list, continuous_dists)

return torch.cat([dist.exported_model_output() for dist in dists], dim=1)
def _get_dists(self, inputs: torch.Tensor, masks: torch.Tensor) -> Tuple[List[DistInstance], List[DiscreteDistInstance]]:
#continuous_distributions: List[DistInstance] = []
#discrete_distributions: List[DiscreteDistInstance] = []
continuous_dist_instances = self.continuous_distribution(inputs)# for continuous_dist in self.continuous_distributions]
discrete_dist_instances = self.discrete_distribution(inputs, masks)# for discrete_dist in self.discrete_distributions]
#for continuous_dist in self.continuous_distributions:
# continuous_distributions += continuous_dist(inputs)
#for discrete_dist in self.discrete_distributions:
# discrete_distributions += discrete_dist(inputs, masks)
return continuous_dist_instances, discrete_dist_instances
continuous_distributions: List[DistInstance] = []
discrete_distributions: List[DiscreteDistInstance] = []
#continuous_dist_instances = self.continuous_distribution(inputs)# for continuous_dist in self.continuous_distributions]
#discrete_dist_instances = self.discrete_distribution(inputs, masks)# for discrete_dist in self.discrete_distributions]
#return continuous_dist_instances, discrete_dist_instances
for continuous_dist in self.continuous_distributions:
continuous_distribution = continuous_dist(inputs)
for cd in continuous_distribution:
continuous_distributions.append(cd)
for discrete_dist in self.discrete_distributions:
discrete_distribution = discrete_dist(inputs, masks)
for dd in discrete_distribution:
discrete_distributions.append(dd)
return continuous_distributions, discrete_distributions
def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
continuous_dists, discrete_dists = self._get_dists(inputs, masks)

正在加载...
取消
保存