浏览代码

fix test torch distributions

/develop/actionmodel-csharp
Andrew Cohen 4 年前
当前提交
701c1a3f
共有 3 个文件被更改,包括 3 次插入3 次删除
  1. 2
      ml-agents/mlagents/trainers/tests/torch/test_distributions.py
  2. 2
      ml-agents/mlagents/trainers/torch/action_model.py
  3. 2
      ml-agents/mlagents/trainers/torch/distributions.py

2
ml-agents/mlagents/trainers/tests/torch/test_distributions.py


optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)
for _ in range(50):
dist_inst = gauss_dist(sample_embedding)[0]
dist_inst = gauss_dist(sample_embedding)
if tanh_squash:
assert isinstance(dist_inst, TanhGaussianDistInstance)
else:

2
ml-agents/mlagents/trainers/torch/action_model.py


discrete_dist: Optional[List[DiscreteDistInstance]] = None
# This checks None because mypy complains otherwise
if self._continuous_distribution is not None:
continuous_dist = self._continuous_distribution(inputs, masks)
continuous_dist = self._continuous_distribution(inputs)
if self._discrete_distribution is not None:
discrete_dist = self._discrete_distribution(inputs, masks)
return DistInstances(continuous_dist, discrete_dist)

2
ml-agents/mlagents/trainers/torch/distributions.py


torch.zeros(1, num_outputs, requires_grad=True)
)
def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
mu = self.mu(inputs)
if self.conditional_sigma:
log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)

正在加载...
取消
保存