
all hybrid simple rl tests pass

/develop/actionmodel-csharp
Andrew Cohen, 4 years ago
Current commit 7750bccd
2 changed files with 4 additions and 12 deletions
  1. ml-agents/mlagents/trainers/torch/action_model.py (9 changes)
  2. ml-agents/mlagents/trainers/torch/utils.py (7 changes)

ml-agents/mlagents/trainers/torch/action_model.py (9 changes)


 import abc
-import numpy as np
-import math
-from mlagents.trainers.torch.layers import linear_layer, Initialization
 from mlagents.trainers.torch.distributions import (
     DistInstance,
     DiscreteDistInstance,
     GaussianDistribution,
     MultiCategoricalDistribution,
 )

     def _get_dists(
         self, inputs: torch.Tensor, masks: torch.Tensor
-    ) -> Tuple[List[DistInstance], List[DiscreteDistInstance]]:
+    ) -> List[DistInstance]:
         distribution_instances: List[DistInstance] = []
         for distribution in self._distributions:
             dist_instances = distribution(inputs, masks)
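
The hunk above narrows _get_dists to return a single flat List[DistInstance] instead of a (continuous, discrete) tuple. A minimal sketch of why one list suffices, using torch.distributions objects purely as stand-ins for the ml-agents DistInstance wrappers (the classes, names, and shapes below are illustrative assumptions, not project code):

import torch
from torch.distributions import Normal, Categorical

# Both continuous and discrete distributions expose the same sample /
# log_prob / entropy surface, so callers can iterate one flat list
# without caring which kind each element is.
dist_instances = [
    Normal(torch.zeros(3), torch.ones(3)),   # continuous action, batch of 3
    Categorical(logits=torch.rand(3, 4)),    # discrete branch with 4 options
]
for dist in dist_instances:
    action = dist.sample()
    print(dist.log_prob(action).shape, dist.entropy().shape)  # (3,), (3,)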

         )
         # Use the sum of entropy across actions, not the mean
         entropy_sum = torch.sum(entropies, dim=1)
-        return (actions, log_probs, entropies)
+        return (actions, log_probs, entropy_sum)
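
The comment in this hunk states the intent: report total policy entropy per timestep by summing across the action dimension rather than averaging. A self-contained sketch of that bookkeeping, with the batch size and the three per-action entropy columns chosen only for illustration (not values from the source):

import torch

# Assumed per-action entropies, each of shape (batch, 1) for a batch of 3.
continuous_entropy = torch.rand(3, 1)
branch_a_entropy = torch.rand(3, 1)
branch_b_entropy = torch.rand(3, 1)

# Joined along dim=1 this gives (batch, num_actions); summing that dimension
# yields one total-entropy value per timestep instead of a per-action mean.
entropies = torch.cat([continuous_entropy, branch_a_entropy, branch_b_entropy], dim=1)
entropy_sum = torch.sum(entropies, dim=1)
print(entropy_sum.shape)  # torch.Size([3])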

ml-agents/mlagents/trainers/torch/utils.py (7 changes)


     @staticmethod
     def get_probs_and_entropy(
         action_list: List[torch.Tensor], dists: List[DistInstance]
-    ) -> Tuple[List[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]:
+    ) -> Tuple[List[torch.Tensor], torch.Tensor, Optional[List[torch.Tensor]]]:
         log_probs_list = []
         all_probs_list = []
         entropies_list = []

             entropies_list.append(entropy)
             if isinstance(action_dist, DiscreteDistInstance):
                 all_probs_list.append(action_dist.all_log_prob())
-        print(entropies_list)
-        entropies = torch.stack(entropies_list, dim=-1)
-        if not all_probs_list:
-            entropies = entropies.squeeze(-1)
+        entropies = torch.cat(entropies_list, dim=1)
         return log_probs_list, entropies, all_probs_list

     @staticmethod
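
The utils.py change replaces torch.stack(..., dim=-1) plus a conditional squeeze with a single torch.cat(..., dim=1), drops a leftover debug print, and widens the third return value to a list of tensors. The shape effect is easiest to see in isolation; the (batch, 1) entropy shape below is an assumption implied by concatenating along dim=1, not something stated in the diff:

import torch

batch_size = 3

# Assumed: each per-distribution entropy already carries an action dimension,
# i.e. shape (batch, 1), so cat can join them along the existing dim=1.
entropies_list = [torch.rand(batch_size, 1), torch.rand(batch_size, 1)]
entropies = torch.cat(entropies_list, dim=1)
print(entropies.shape)  # torch.Size([3, 2])

# The removed pattern stacked 1-D (batch,) entropies onto a new trailing
# dimension and then needed a squeeze(-1) in the single-distribution case.
flat_entropies = [torch.rand(batch_size)]
stacked = torch.stack(flat_entropies, dim=-1)
print(stacked.shape)              # torch.Size([3, 1])
print(stacked.squeeze(-1).shape)  # torch.Size([3])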
