
[bug-fix] Improve performance for PPO with continuous actions (#3662)

/bug-failed-api-check
GitHub, 5 years ago
Current commit
29f82921
3 changed files, with 29 insertions and 13 deletions
  1. ml-agents/mlagents/trainers/distributions.py (36 changes)
  2. ml-agents/mlagents/trainers/policy/nn_policy.py (1 change)
  3. ml-agents/mlagents/trainers/tests/test_simple_rl.py (5 changes)

ml-agents/mlagents/trainers/distributions.py (36 changes)


         act_size: List[int],
         reparameterize: bool = False,
         tanh_squash: bool = False,
+        condition_sigma: bool = True,
         log_sigma_min: float = -20,
         log_sigma_max: float = 2,
     ):

         :param log_sigma_max: Maximum log standard deviation to clip by.
         """
         encoded = self._create_mu_log_sigma(
-            logits, act_size, log_sigma_min, log_sigma_max
+            logits,
+            act_size,
+            log_sigma_min,
+            log_sigma_max,
+            condition_sigma=condition_sigma,
         )
         self._sampled_policy = self._create_sampled_policy(encoded)
         if not reparameterize:

         act_size: List[int],
         log_sigma_min: float,
         log_sigma_max: float,
+        condition_sigma: bool,
     ) -> "GaussianDistribution.MuSigmaTensors":
         mu = tf.layers.dense(

             reuse=tf.AUTO_REUSE,
         )
-        # Policy-dependent log_sigma_sq
-        log_sigma = tf.layers.dense(
-            logits,
-            act_size[0],
-            activation=None,
-            name="log_std",
-            kernel_initializer=ModelUtils.scaled_init(0.01),
-        )
+        if condition_sigma:
+            # Policy-dependent log_sigma_sq
+            log_sigma = tf.layers.dense(
+                logits,
+                act_size[0],
+                activation=None,
+                name="log_std",
+                kernel_initializer=ModelUtils.scaled_init(0.01),
+            )
+        else:
+            log_sigma = tf.get_variable(
+                "log_std",
+                [act_size[0]],
+                dtype=tf.float32,
+                initializer=tf.zeros_initializer(),
+            )
         log_sigma = tf.clip_by_value(log_sigma, log_sigma_min, log_sigma_max)
         sigma = tf.exp(log_sigma)
         return self.MuSigmaTensors(mu, log_sigma, sigma)
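In effect this hunk gives the Gaussian policy two ways to produce its standard deviation: conditioned on the observation encoding (the previous behaviour, still the default), or as a single learned log_std vector shared across all observations and initialized to zero. A minimal NumPy sketch of the two modes, an illustration with made-up names rather than ml-agents code, using the clip bounds from the defaults above:

import numpy as np

rng = np.random.default_rng(0)

def sigma_conditioned(obs_encoding, w, b, log_sigma_min=-20.0, log_sigma_max=2.0):
    # condition_sigma=True: log_std comes from a dense layer on top of the
    # observation encoding, so exploration noise can vary from state to state.
    log_sigma = obs_encoding @ w + b
    return np.exp(np.clip(log_sigma, log_sigma_min, log_sigma_max))

def sigma_shared(log_sigma_param, log_sigma_min=-20.0, log_sigma_max=2.0):
    # condition_sigma=False: log_std is a free, zero-initialized parameter shared by
    # every state, so sigma starts at exp(0) = 1 for each action dimension.
    return np.exp(np.clip(log_sigma_param, log_sigma_min, log_sigma_max))

obs_encoding = rng.standard_normal((3, 8))   # batch of 3 encoded observations
w = rng.standard_normal((8, 2)) * 0.01       # stand-in for the log_std dense layer
b = np.zeros(2)
print(sigma_conditioned(obs_encoding, w, b)) # differs per observation
print(sigma_shared(np.zeros(2)))             # [1. 1.] for every observation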

"""
Adjust probabilities for squashed sample before output
"""
probs -= tf.log(1 - squashed_policy ** 2 + EPSILON)
return probs
adjusted_probs = probs - tf.log(1 - squashed_policy ** 2 + EPSILON)
return adjusted_probs
@property
def total_log_probs(self) -> tf.Tensor:
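The hunk above is the standard change-of-variables correction for a tanh-squashed Gaussian policy; the rewritten lines compute the same value without overwriting probs in place. For an action a = tanh(u), with u drawn from the underlying Gaussian with density p(u | s), the corrected log-probability is (EPSILON guards the log exactly as in the code):

\log \pi(a \mid s) = \log p(u \mid s) - \sum_i \log\!\left(1 - \tanh^2(u_i) + \epsilon\right)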

ml-agents/mlagents/trainers/policy/nn_policy.py (1 change)


                 self.act_size,
                 reparameterize=reparameterize,
                 tanh_squash=tanh_squash,
+                condition_sigma=condition_sigma_on_obs,
             )
         if tanh_squash:
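For context, a toy NumPy sketch of the sampling path these arguments feed into: a Gaussian sample built from mu and sigma, squashed through tanh when tanh_squash is enabled. Variable names are invented for the example; this is not the ml-agents graph code:

import numpy as np

rng = np.random.default_rng(0)

mu = np.array([[0.2, -0.5]])          # stand-in for the "mu" dense layer output
log_sigma = np.zeros(2)               # shared log_std (condition_sigma=False), starts at 0
sigma = np.exp(np.clip(log_sigma, -20.0, 2.0))

epsilon = rng.standard_normal(mu.shape)
raw_action = mu + sigma * epsilon     # reparameterized Gaussian sample
action = np.tanh(raw_action)          # applied only when tanh_squash=True
print(raw_action, action)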

ml-agents/mlagents/trainers/tests/test_simple_rl.py (5 changes)


 def test_recurrent_ppo(use_discrete):
     env = Memory1DEnvironment([BRAIN_NAME], use_discrete=use_discrete)
     override_vals = {
-        "max_steps": 3000,
+        "max_steps": 4000,
+        "learning_rate": 1e-3,
-    _check_environment_trains(env, config)
+    _check_environment_trains(env, config, success_threshold=0.9)
 @pytest.mark.parametrize("use_discrete", [True, False])
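With this change the recurrent PPO check trains for 4000 steps at a learning rate of 1e-3 and passes against an explicit success_threshold of 0.9, presumably the mean-reward bar used by _check_environment_trains. To re-run just this test locally, pytest's -k filter can be pointed at it, for example: pytest ml-agents/mlagents/trainers/tests/test_simple_rl.py -k test_recurrent_ppo.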
