|
|
|
|
|
|
act_size: List[int], |
|
|
|
reparameterize: bool = False, |
|
|
|
tanh_squash: bool = False, |
|
|
|
condition_sigma: bool = True, |
|
|
|
log_sigma_min: float = -20, |
|
|
|
log_sigma_max: float = 2, |
|
|
|
): |
|
|
|
|
|
|
:param log_sigma_max: Maximum log standard deviation to clip by. |
|
|
|
""" |
|
|
|
encoded = self._create_mu_log_sigma( |
|
|
|
logits, act_size, log_sigma_min, log_sigma_max |
|
|
|
logits, |
|
|
|
act_size, |
|
|
|
log_sigma_min, |
|
|
|
log_sigma_max, |
|
|
|
condition_sigma=condition_sigma, |
|
|
|
) |
|
|
|
self._sampled_policy = self._create_sampled_policy(encoded) |
|
|
|
if not reparameterize: |
|
|
|
|
|
|
act_size: List[int], |
|
|
|
log_sigma_min: float, |
|
|
|
log_sigma_max: float, |
|
|
|
condition_sigma: bool, |
|
|
|
) -> "GaussianDistribution.MuSigmaTensors": |
|
|
|
|
|
|
|
mu = tf.layers.dense( |
|
|
|
|
|
|
reuse=tf.AUTO_REUSE, |
|
|
|
) |
|
|
|
|
|
|
|
# Policy-dependent log_sigma_sq |
|
|
|
log_sigma = tf.layers.dense( |
|
|
|
logits, |
|
|
|
act_size[0], |
|
|
|
activation=None, |
|
|
|
name="log_std", |
|
|
|
kernel_initializer=ModelUtils.scaled_init(0.01), |
|
|
|
) |
|
|
|
if condition_sigma: |
|
|
|
# Policy-dependent log_sigma_sq |
|
|
|
log_sigma = tf.layers.dense( |
|
|
|
logits, |
|
|
|
act_size[0], |
|
|
|
activation=None, |
|
|
|
name="log_std", |
|
|
|
kernel_initializer=ModelUtils.scaled_init(0.01), |
|
|
|
) |
|
|
|
else: |
|
|
|
log_sigma = tf.get_variable( |
|
|
|
"log_std", |
|
|
|
[act_size[0]], |
|
|
|
dtype=tf.float32, |
|
|
|
initializer=tf.zeros_initializer(), |
|
|
|
) |
|
|
|
log_sigma = tf.clip_by_value(log_sigma, log_sigma_min, log_sigma_max) |
|
|
|
sigma = tf.exp(log_sigma) |
|
|
|
return self.MuSigmaTensors(mu, log_sigma, sigma) |
|
|
|
|
|
|
""" |
|
|
|
Adjust probabilities for squashed sample before output |
|
|
|
""" |
|
|
|
probs -= tf.log(1 - squashed_policy ** 2 + EPSILON) |
|
|
|
return probs |
|
|
|
adjusted_probs = probs - tf.log(1 - squashed_policy ** 2 + EPSILON) |
|
|
|
return adjusted_probs |
|
|
|
|
|
|
|
@property |
|
|
|
def total_log_probs(self) -> tf.Tensor: |
|
|
|