
/exp-tanh
vincentpierre, 4 years ago
Current commit: 5f9ea5ea
2 files changed, 15 insertions(+), 11 deletions(-)
  1. ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changed lines)
  2. ml-agents/mlagents/trainers/torch/distributions.py (24 changed lines)

ml-agents/mlagents/trainers/ppo/optimizer_torch.py (2 changed lines)


        value_loss = self.ppo_value_loss(
            values, old_values, returns, decay_eps, loss_masks
        )
        print(log_probs)
        # print(log_probs)
        policy_loss = self.ppo_policy_loss(
            ModelUtils.list_to_tensor(batch["advantages"]),
            log_probs,
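For reference, `ppo_policy_loss` in PPO implementations computes the clipped surrogate objective from the new log-probabilities, the log-probabilities recorded when the data was collected, and the advantages. The following is a minimal generic sketch of that loss, not this repository's actual implementation; the argument names and the default clip value `epsilon=0.2` are assumptions.

import torch

def ppo_clipped_policy_loss(advantages, log_probs, old_log_probs, loss_masks, epsilon=0.2):
    # Ratio between the updated policy and the policy that collected the data.
    ratio = torch.exp(log_probs - old_log_probs)
    # Clipped surrogate: keep the more pessimistic of the clipped/unclipped terms.
    surrogate_1 = ratio * advantages
    surrogate_2 = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    # Negate (optimizers minimize) and average over valid timesteps only.
    return -torch.sum(torch.min(surrogate_1, surrogate_2) * loss_masks) / torch.sum(loss_masks)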

ml-agents/mlagents/trainers/torch/distributions.py (24 changed lines)


        squashed = self.transform(unsquashed_sample)
        return squashed

    def _inverse_tanh(self, value):
        # capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
        capped_value = (1 - EPSILON) * value
        return 0.5 * torch.log((1 + capped_value) / (1 - capped_value) + EPSILON)

    # def _inverse_tanh(self, value):
    #     # capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
    #     capped_value = (1 - EPSILON) * value
    #     return 0.5 * torch.log((1 + capped_value) / (1 - capped_value) + EPSILON)

    def log_prob(self, value):
        # capped_value = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
        capped_value = 0.1 * value
        unsquashed = self.transform.inv(capped_value)
        unsquashed = self.transform.inv(value)
        # unsquashed = self.transform.inv(value * 0.85)
        print("tmp decomposition", value, capped_value, unsquashed, super().log_prob(unsquashed), self.transform.log_abs_det_jacobian(
            unsquashed, None
        ))
        # print("tmp decomposition", value, capped_value, unsquashed, super().log_prob(unsquashed), self.transform.log_abs_det_jacobian(
        #     unsquashed, None
        # ))
        if torch.isnan(torch.mean(value)):
            print("Nan in log_prob(self, value), value")
        if torch.isnan(torch.mean(super().log_prob(unsquashed))):
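The two `_inverse_tanh` variants above differ only in how they keep the inverse tanh finite near the saturation points: atanh(x) = 0.5 * log((1 + x) / (1 - x)) diverges as |x| approaches 1, so the input is either clamped into [-1 + EPSILON, 1 - EPSILON] or shrunk multiplicatively by (1 - EPSILON). Below is a standalone sketch of both options for comparison; the EPSILON value here is an assumed placeholder, the repository defines its own constant.

import torch

EPSILON = 1e-7  # assumed placeholder; not necessarily the value used by ml-agents

def inverse_tanh_clamp(value):
    # Clamp the input into the open interval (-1, 1) so the log stays finite.
    capped = torch.clamp(value, -1 + EPSILON, 1 - EPSILON)
    return 0.5 * torch.log((1 + capped) / (1 - capped))

def inverse_tanh_scale(value):
    # Shrink towards zero instead of clamping, which avoids a zero gradient at the clamp boundary.
    capped = (1 - EPSILON) * value
    return 0.5 * torch.log((1 + capped) / (1 - capped))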

        return tmp

    def exported_model_output(self):
        return self.sample()


class CategoricalDistInstance(DiscreteDistInstance):
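The "tmp decomposition" print above splits the squashed log-probability into the two terms of the change-of-variables formula: the base Gaussian log-probability of the unsquashed action minus the log-determinant of the tanh Jacobian. A minimal sketch of that identity using torch's built-in TanhTransform, assuming a diagonal Gaussian base distribution (this is an illustration, not the class's actual code):

import torch
from torch.distributions import Normal
from torch.distributions.transforms import TanhTransform

def tanh_squashed_log_prob(mean, std, squashed_value):
    # log p(y) = log p_base(x) - log |d tanh(x)/dx|, with x = atanh(y).
    base = Normal(mean, std)
    transform = TanhTransform(cache_size=1)
    unsquashed = transform.inv(squashed_value)
    return base.log_prob(unsquashed) - transform.log_abs_det_jacobian(
        unsquashed, squashed_value
    )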
