|
|
|
|
|
|
and 1 for unmasked. |
|
|
|
""" |
|
|
|
unmasked_log_probs = self._create_policy_branches(logits, act_size) |
|
|
|
self._sampled_policy, self._all_probs, action_index = self._get_masked_actions_probs( |
|
|
|
unmasked_log_probs, act_size, action_masks |
|
|
|
) |
|
|
|
( |
|
|
|
self._sampled_policy, |
|
|
|
self._all_probs, |
|
|
|
action_index, |
|
|
|
) = self._get_masked_actions_probs(unmasked_log_probs, act_size, action_masks) |
|
|
|
self._entropy = self._create_entropy( |
|
|
|
self._sampled_onehot, self._all_probs, action_index, act_size |
|
|
|
) |
|
|
|
self._entropy = self._create_entropy(self._all_probs, action_index, act_size) |
|
|
|
self._total_prob = self._get_log_probs( |
|
|
|
self._sampled_onehot, self._all_probs, action_index, act_size |
|
|
|
) |
|
|
|
|
|
|
return log_probs |
|
|
|
|
|
|
|
def _create_entropy( |
|
|
|
self, |
|
|
|
all_log_probs: tf.Tensor, |
|
|
|
sample_onehot: tf.Tensor, |
|
|
|
action_idx: List[int], |
|
|
|
act_size: List[int], |
|
|
|
self, all_log_probs: tf.Tensor, action_idx: List[int], act_size: List[int] |
|
|
|
) -> tf.Tensor: |
|
|
|
entropy = tf.reduce_sum( |
|
|
|
( |
|
|
|