浏览代码

[bug-fix] Fix entropy computation in MultiCategorialDistribution (#3607)

/bug-failed-api-check
GitHub 5 年前
当前提交
ed2eb6ef
共有 2 个文件被更改,包括 11 次插入13 次删除
  1. 18
      ml-agents/mlagents/trainers/distributions.py
  2. 6
      ml-agents/mlagents/trainers/tests/test_distributions.py

18
ml-agents/mlagents/trainers/distributions.py


and 1 for unmasked.
"""
unmasked_log_probs = self._create_policy_branches(logits, act_size)
self._sampled_policy, self._all_probs, action_index = self._get_masked_actions_probs(
unmasked_log_probs, act_size, action_masks
)
(
self._sampled_policy,
self._all_probs,
action_index,
) = self._get_masked_actions_probs(unmasked_log_probs, act_size, action_masks)
self._entropy = self._create_entropy(
self._sampled_onehot, self._all_probs, action_index, act_size
)
self._entropy = self._create_entropy(self._all_probs, action_index, act_size)
self._total_prob = self._get_log_probs(
self._sampled_onehot, self._all_probs, action_index, act_size
)

return log_probs
def _create_entropy(
self,
all_log_probs: tf.Tensor,
sample_onehot: tf.Tensor,
action_idx: List[int],
act_size: List[int],
self, all_log_probs: tf.Tensor, action_idx: List[int], act_size: List[int]
) -> tf.Tensor:
entropy = tf.reduce_sum(
(

6
ml-agents/mlagents/trainers/tests/test_distributions.py


sess.run(init)
output = sess.run(distribution.sample)
for _ in range(10):
sample, log_probs = sess.run(
[distribution.sample, distribution.log_probs]
sample, log_probs, entropy = sess.run(
[distribution.sample, distribution.log_probs, distribution.entropy]
)
assert len(log_probs[0]) == sum(DISCRETE_ACTION_SPACE)
# Assert action never exceeds [-1,1]

output = sess.run([distribution.total_log_probs])
assert output[0].shape[0] == 1
# Make sure entropy is correct
assert entropy[0] > 3.8
# Test masks
mask = []

正在加载...
取消
保存