
[change] Separate action outputs into OutputDistributions object (#3514)

/asymm-envs
GitHub, 5 years ago
Current commit: 7d954797
8 files changed, 489 insertions(+), 147 deletions(-)
  1. ml-agents/mlagents/trainers/common/nn_policy.py (159 changes)
  2. ml-agents/mlagents/trainers/ppo/optimizer.py (4 changes)
  3. ml-agents/mlagents/trainers/sac/optimizer.py (4 changes)
  4. ml-agents/mlagents/trainers/tests/test_bcmodule.py (10 changes)
  5. ml-agents/mlagents/trainers/tests/test_simple_rl.py (4 changes)
  6. ml-agents/mlagents/trainers/tf_policy.py (5 changes)
  7. ml-agents/mlagents/trainers/common/distributions.py (310 changes)
  8. ml-agents/mlagents/trainers/tests/test_distributions.py (140 changes)

ml-agents/mlagents/trainers/common/nn_policy.py (159 changes)


import logging
import numpy as np
from typing import Any, Dict, Optional, List
from mlagents.tf_utils import tf

from mlagents.trainers.models import EncoderType
from mlagents.trainers.models import ModelUtils
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.common.distributions import (
GaussianDistribution,
MultiCategoricalDistribution,
)
logger = logging.getLogger("mlagents.trainers")

hidden_policy = encoded
with tf.variable_scope("policy"):
mu = tf.layers.dense(
distribution = GaussianDistribution(
self.act_size[0],
activation=None,
name="mu",
kernel_initializer=ModelUtils.scaled_init(0.01),
reuse=tf.AUTO_REUSE,
self.act_size,
reparameterize=reparameterize,
tanh_squash=tanh_squash,
# Policy-dependent log_sigma
if condition_sigma_on_obs:
log_sigma = tf.layers.dense(
hidden_policy,
self.act_size[0],
activation=None,
name="log_sigma",
kernel_initializer=ModelUtils.scaled_init(0.01),
)
else:
log_sigma = tf.get_variable(
"log_sigma",
[self.act_size[0]],
dtype=tf.float32,
initializer=tf.zeros_initializer(),
)
log_sigma = tf.clip_by_value(log_sigma, self.log_std_min, self.log_std_max)
sigma = tf.exp(log_sigma)
epsilon = tf.random_normal(tf.shape(mu))
sampled_policy = mu + sigma * epsilon
# Stop gradient if we're not doing the resampling trick
if not reparameterize:
sampled_policy_probs = tf.stop_gradient(sampled_policy)
else:
sampled_policy_probs = sampled_policy
# Compute probability of model output.
_gauss_pre = -0.5 * (
((sampled_policy_probs - mu) / (sigma + EPSILON)) ** 2
+ 2 * log_sigma
+ np.log(2 * np.pi)
)
all_probs = _gauss_pre
all_probs = tf.reduce_sum(_gauss_pre, axis=1, keepdims=True)
self.output_pre = tf.tanh(sampled_policy)
# Squash correction
all_probs -= tf.reduce_sum(
tf.log(1 - self.output_pre ** 2 + EPSILON), axis=1, keepdims=True
)
self.output_pre = distribution.sample
self.output_pre = sampled_policy
self.output_pre = distribution.sample
# Clip and scale output to ensure actions are always within [-1, 1] range.
output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
self.output = tf.identity(output_post, name="action")

self.all_log_probs = tf.identity(all_probs, name="action_probs")
single_dim_entropy = 0.5 * tf.reduce_mean(
tf.log(2 * np.pi * np.e) + 2 * log_sigma
)
# Make entropy the right shape
self.entropy = tf.ones_like(tf.reshape(mu[:, 0], [-1])) * single_dim_entropy
self.all_log_probs = tf.identity(distribution.log_probs, name="action_probs")
self.entropy = distribution.entropy
self.log_probs = tf.reduce_sum(
(tf.identity(self.all_log_probs)), axis=1, keepdims=True
)
self.total_log_probs = distribution.total_log_probs
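
For reference, the per-dimension log probability removed above and now computed inside GaussianDistribution is the standard diagonal-Gaussian log density, with a tanh change-of-variables correction when tanh_squash is enabled (EPSILON guards the division and the log):

$$\log p(a_i) = -\tfrac{1}{2}\Big(\big(\tfrac{a_i-\mu_i}{\sigma_i+\epsilon}\big)^2 + 2\log\sigma_i + \log 2\pi\Big), \qquad \log p(\tanh(a_i)) = \log p(a_i) - \log\big(1-\tanh^2(a_i)+\epsilon\big)$$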
def _create_dc_actor(self, encoded: tf.Tensor) -> None:
"""

else:
hidden_policy = encoded
policy_branches = []
with tf.variable_scope("policy"):
for size in self.act_size:
policy_branches.append(
tf.layers.dense(
hidden_policy,
size,
activation=None,
use_bias=False,
kernel_initializer=ModelUtils.scaled_init(0.01),
)
)
raw_log_probs = tf.concat(policy_branches, axis=1, name="action_probs")
output, self.action_probs, normalized_logits = ModelUtils.create_discrete_action_masking_layer(
raw_log_probs, self.action_masks, self.act_size
)
self.output = tf.identity(output)
self.all_log_probs = tf.identity(normalized_logits, name="action")
self.action_oh = tf.concat(
[
tf.one_hot(self.output[:, i], self.act_size[i])
for i in range(len(self.act_size))
],
axis=1,
)
self.selected_actions = tf.stop_gradient(self.action_oh)
action_idx = [0] + list(np.cumsum(self.act_size))
self.entropy = tf.reduce_sum(
(
tf.stack(
[
tf.nn.softmax_cross_entropy_with_logits_v2(
labels=tf.nn.softmax(
self.all_log_probs[:, action_idx[i] : action_idx[i + 1]]
),
logits=self.all_log_probs[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
)
self.log_probs = tf.reduce_sum(
(
tf.stack(
[
-tf.nn.softmax_cross_entropy_with_logits_v2(
labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]],
logits=normalized_logits[
:, action_idx[i] : action_idx[i + 1]
],
)
for i in range(len(self.act_size))
],
axis=1,
)
),
axis=1,
keepdims=True,
)
with tf.variable_scope("policy"):
distribution = MultiCategoricalDistribution(
hidden_policy, self.act_size, self.action_masks
)
# It's important that we are able to feed_dict a value into this tensor to get the
# right one-hot encoding, so we can't do identity on it.
self.output = distribution.sample
self.all_log_probs = tf.identity(distribution.log_probs, name="action")
self.selected_actions = tf.stop_gradient(
distribution.sample_onehot
) # In discrete, these are onehot
self.entropy = distribution.entropy
self.total_log_probs = distribution.total_log_probs
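
The per-branch log probabilities above (and in MultiCategoricalDistribution._get_log_probs below) rely on the identity that softmax cross-entropy against a one-hot label equals the negative log-softmax of the selected action, summed over branches:

$$\log\pi(a) = \sum_b \log\,\mathrm{softmax}(z_b)_{a_b} = -\sum_b \mathrm{CE}\big(\mathrm{onehot}(a_b),\, z_b\big)$$

where $z_b$ are the (masked) logits of branch $b$.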

ml-agents/mlagents/trainers/ppo/optimizer.py (4 changes)


lr_schedule, lr, self.policy.global_step, int(max_step)
)
self._create_losses(
self.policy.log_probs,
self.policy.total_log_probs,
self.old_log_probs,
self.value_heads,
self.policy.entropy,

tf.stack(
[
-tf.nn.softmax_cross_entropy_with_logits_v2(
labels=self.policy.action_oh[
labels=self.policy.selected_actions[
:, action_idx[i] : action_idx[i + 1]
],
logits=old_normalized_logits[
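
Feeding self.policy.total_log_probs (the per-sample sum over action dimensions or branches) into _create_losses is consistent with how the PPO surrogate is usually written in terms of the joint action probability ratio:

$$r_t(\theta) = \exp\big(\log\pi_\theta(a_t\mid s_t) - \log\pi_{\theta_\mathrm{old}}(a_t\mid s_t)\big), \qquad \log\pi(a_t\mid s_t) = \sum_i \log\pi(a_{t,i}\mid s_t)$$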

ml-agents/mlagents/trainers/sac/optimizer.py (4 changes)


if discrete:
# We need to break up the Q functions by branch, and update them individually.
branched_q1_stream = self._apply_as_branches(
self.policy.action_oh * q1_streams[name]
self.policy.selected_actions * q1_streams[name]
self.policy.action_oh * q2_streams[name]
self.policy.selected_actions * q2_streams[name]
)
# Reduce each branch into scalar
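
The "break up by branch" step can be pictured as a plain tf.split along the concatenated action axis. A minimal sketch of the idea (an illustration only, not the actual _apply_as_branches helper, whose body is not shown in this diff), assuming act_size holds the per-branch sizes:

from mlagents.tf_utils import tf

def split_into_branches(concatenated, act_size):
    # Split a (batch_size x sum(act_size)) tensor into one
    # (batch_size x act_size[i]) tensor per discrete branch.
    return tf.split(concatenated, act_size, axis=1)

# e.g. branched_q1 = split_into_branches(self.policy.selected_actions * q1_streams[name], act_size),
# after which each branch is reduced with tf.reduce_sum(branch, axis=1, keepdims=True).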

ml-agents/mlagents/trainers/tests/test_bcmodule.py (10 changes)


# Test with continuous control env and vector actions
@pytest.mark.parametrize("is_sac", [True, False], ids=["ppo", "sac"])
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_update(is_sac):
mock_brain = mb.create_mock_3dball_brain()
bc_module = create_bc_module(

# Test with constant pretraining learning rate
@pytest.mark.parametrize("is_sac", [True, False], ids=["ppo", "sac"])
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_constant_lr_update(is_sac):
trainer_config = ppo_dummy_config()
mock_brain = mb.create_mock_3dball_brain()

# Test with RNN
@pytest.mark.parametrize("is_sac", [True, False], ids=["ppo", "sac"])
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_update(is_sac):
mock_brain = mb.create_mock_3dball_brain()
bc_module = create_bc_module(

# Test with discrete control and visual observations
@pytest.mark.parametrize("is_sac", [True, False], ids=["ppo", "sac"])
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_dc_visual_update(is_sac):
mock_brain = mb.create_mock_banana_brain()
bc_module = create_bc_module(

# Test with discrete control, visual observations and RNN
@pytest.mark.parametrize("is_sac", [True, False], ids=["ppo", "sac"])
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_bcmodule_rnn_dc_update(is_sac):
mock_brain = mb.create_mock_banana_brain()
bc_module = create_bc_module(
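
The reordered ids above matter because pytest pairs ids positionally with the parameter values, so [True, False] should be labelled ["sac", "ppo"]. A minimal self-contained illustration (hypothetical test name):

import pytest

# ids pair positionally with the values: True -> "sac", False -> "ppo".
@pytest.mark.parametrize("is_sac", [True, False], ids=["sac", "ppo"])
def test_ids_pairing(is_sac):
    assert isinstance(is_sac, bool)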

ml-agents/mlagents/trainers/tests/test_simple_rl.py (4 changes)


num_layers: 1
time_horizon: 64
sequence_length: 64
summary_freq: 500
tau: 0.005
summary_freq: 100
tau: 0.01
use_recurrent: false
curiosity_enc_size: 128
demo_path: None

ml-agents/mlagents/trainers/tf_policy.py (5 changes)


self.update_normalization_op: Optional[tf.Operation] = None
self.value: Optional[tf.Tensor] = None
self.all_log_probs: tf.Tensor = None
self.log_probs: Optional[tf.Tensor] = None
self.total_log_probs: Optional[tf.Tensor] = None
self.action_oh: tf.Tensor = None
self.selected_actions: Optional[tf.Tensor] = None
self.selected_actions: tf.Tensor = None
self.action_masks: Optional[tf.Tensor] = None
self.prev_action: Optional[tf.Tensor] = None
self.memory_in: Optional[tf.Tensor] = None

ml-agents/mlagents/trainers/common/distributions.py (310 changes)


import abc
from typing import NamedTuple, List, Tuple
import numpy as np
from mlagents.tf_utils import tf
from mlagents.trainers.models import ModelUtils
EPSILON = 1e-6 # Small value to avoid divide by zero
class OutputDistribution(abc.ABC):
@abc.abstractproperty
def log_probs(self) -> tf.Tensor:
"""
Returns a Tensor that when evaluated, produces the per-action log probabilities of this distribution.
The shape of this Tensor should be equivalent to (batch_size x the number of actions) produced in sample.
"""
pass
@abc.abstractproperty
def total_log_probs(self) -> tf.Tensor:
"""
Returns a Tensor that when evaluated, produces the total log probability for a single sample.
The shape of this Tensor should be equivalent to (batch_size x 1) produced in sample.
"""
pass
@abc.abstractproperty
def sample(self) -> tf.Tensor:
"""
Returns a Tensor that when evaluated, produces a sample of this OutputDistribution.
"""
pass
@abc.abstractproperty
def entropy(self) -> tf.Tensor:
"""
Returns a Tensor that when evaluated, produces the entropy of this distribution.
"""
pass
class DiscreteOutputDistribution(OutputDistribution):
@abc.abstractproperty
def sample_onehot(self) -> tf.Tensor:
"""
Returns a one-hot version of the output.
"""
class GaussianDistribution(OutputDistribution):
"""
A Gaussian output distribution for continuous actions.
"""
class MuSigmaTensors(NamedTuple):
mu: tf.Tensor
log_sigma: tf.Tensor
sigma: tf.Tensor
def __init__(
self,
logits: tf.Tensor,
act_size: List[int],
reparameterize: bool = False,
tanh_squash: bool = False,
log_sigma_min: float = -20,
log_sigma_max: float = 2,
):
"""
A Gaussian output distribution for continuous actions.
:param logits: Hidden layer to use as the input to the Gaussian distribution.
:param act_size: List containing the number of continuous actions.
:param reparameterize: Whether or not to use the reparameterization trick (block gradients through
log probability calculation.)
:param tanh_squash: Squash the output using tanh, constraining it between -1 and 1.
From: Haarnoja et. al, https://arxiv.org/abs/1801.01290
:param log_sigma_min: Minimum log standard deviation to clip by.
:param log_sigma_max: Maximum log standard deviation to clip by.
"""
encoded = self._create_mu_log_sigma(
logits, act_size, log_sigma_min, log_sigma_max
)
self._sampled_policy = self._create_sampled_policy(encoded)
if not reparameterize:
_sampled_policy_probs = tf.stop_gradient(self._sampled_policy)
else:
_sampled_policy_probs = self._sampled_policy
self._all_probs = self._create_log_probs(_sampled_policy_probs, encoded)
if tanh_squash:
self._sampled_policy = tf.tanh(self._sampled_policy)
self._all_probs = self._do_squash_correction_for_tanh(
self._all_probs, self._sampled_policy
)
self._total_prob = tf.reduce_sum(self._all_probs, axis=1, keepdims=True)
self._entropy = self._create_entropy(encoded)
def _create_mu_log_sigma(
self,
logits: tf.Tensor,
act_size: List[int],
log_sigma_min: float,
log_sigma_max: float,
) -> "GaussianDistribution.MuSigmaTensors":
mu = tf.layers.dense(
logits,
act_size[0],
activation=None,
name="mu",
kernel_initializer=ModelUtils.scaled_init(0.01),
reuse=tf.AUTO_REUSE,
)
# Policy-dependent log_sigma_sq
log_sigma = tf.layers.dense(
logits,
act_size[0],
activation=None,
name="log_std",
kernel_initializer=ModelUtils.scaled_init(0.01),
)
log_sigma = tf.clip_by_value(log_sigma, log_sigma_min, log_sigma_max)
sigma = tf.exp(log_sigma)
return self.MuSigmaTensors(mu, log_sigma, sigma)
def _create_sampled_policy(
self, encoded: "GaussianDistribution.MuSigmaTensors"
) -> tf.Tensor:
epsilon = tf.random_normal(tf.shape(encoded.mu))
sampled_policy = encoded.mu + encoded.sigma * epsilon
return sampled_policy
def _create_log_probs(
self, sampled_policy: tf.Tensor, encoded: "GaussianDistribution.MuSigmaTensors"
) -> tf.Tensor:
_gauss_pre = -0.5 * (
((sampled_policy - encoded.mu) / (encoded.sigma + EPSILON)) ** 2
+ 2 * encoded.log_sigma
+ np.log(2 * np.pi)
)
return _gauss_pre
def _create_entropy(
self, encoded: "GaussianDistribution.MuSigmaTensors"
) -> tf.Tensor:
single_dim_entropy = 0.5 * tf.reduce_mean(
tf.log(2 * np.pi * np.e) + tf.square(encoded.log_sigma)
)
# Make entropy the right shape
return tf.ones_like(tf.reshape(encoded.mu[:, 0], [-1])) * single_dim_entropy
def _do_squash_correction_for_tanh(self, probs, squashed_policy):
"""
Adjust probabilities for squashed sample before output
"""
probs -= tf.log(1 - squashed_policy ** 2 + EPSILON)
return probs
@property
def total_log_probs(self) -> tf.Tensor:
return self._total_prob
@property
def log_probs(self) -> tf.Tensor:
return self._all_probs
@property
def sample(self) -> tf.Tensor:
return self._sampled_policy
@property
def entropy(self) -> tf.Tensor:
return self._entropy
class MultiCategoricalDistribution(DiscreteOutputDistribution):
"""
A categorical distribution for multi-branched discrete actions. Also supports action masking.
"""
def __init__(self, logits: tf.Tensor, act_size: List[int], action_masks: tf.Tensor):
"""
A categorical distribution for multi-branched discrete actions.
:param logits: Hidden layer to use as the input to the Gaussian distribution.
:param act_size: List containing the number of discrete actions per branch.
:param action_masks: Tensor representing action masks. Should be of length sum(act_size), and 0 for masked
and 1 for unmasked.
"""
unmasked_log_probs = self._create_policy_branches(logits, act_size)
self._sampled_policy, self._all_probs, action_index = self._get_masked_actions_probs(
unmasked_log_probs, act_size, action_masks
)
self._sampled_onehot = self._action_onehot(self._sampled_policy, act_size)
self._entropy = self._create_entropy(
self._sampled_onehot, self._all_probs, action_index, act_size
)
self._total_prob = self._get_log_probs(
self._sampled_onehot, self._all_probs, action_index, act_size
)
def _create_policy_branches(
self, logits: tf.Tensor, act_size: List[int]
) -> List[tf.Tensor]:
policy_branches = []
for size in act_size:
policy_branches.append(
tf.layers.dense(
logits,
size,
activation=None,
use_bias=False,
kernel_initializer=ModelUtils.scaled_init(0.01),
)
)
unmasked_log_probs = tf.concat(policy_branches, axis=1)
return unmasked_log_probs
def _get_masked_actions_probs(
self,
unmasked_log_probs: tf.Tensor,
act_size: List[int],
action_masks: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor, np.ndarray]:
output, _, all_log_probs = ModelUtils.create_discrete_action_masking_layer(
unmasked_log_probs, action_masks, act_size
)
action_idx = [0] + list(np.cumsum(act_size))
return output, all_log_probs, action_idx
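
A quick worked example of the action_idx bookkeeping returned here, using the same branch sizes as DISCRETE_ACTION_SPACE in the tests below:

import numpy as np

act_size = [3, 3, 3, 2]                       # e.g. DISCRETE_ACTION_SPACE
action_idx = [0] + list(np.cumsum(act_size))  # [0, 3, 6, 9, 11]
# Branch i occupies columns action_idx[i]:action_idx[i + 1] of the
# concatenated (batch_size x sum(act_size)) log-prob tensor.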
def _action_onehot(self, sample: tf.Tensor, act_size: List[int]) -> tf.Tensor:
action_oh = tf.concat(
[tf.one_hot(sample[:, i], act_size[i]) for i in range(len(act_size))],
axis=1,
)
return action_oh
def _get_log_probs(
self,
sample_onehot: tf.Tensor,
all_log_probs: tf.Tensor,
action_idx: List[int],
act_size: List[int],
) -> tf.Tensor:
log_probs = tf.reduce_sum(
(
tf.stack(
[
-tf.nn.softmax_cross_entropy_with_logits_v2(
labels=sample_onehot[:, action_idx[i] : action_idx[i + 1]],
logits=all_log_probs[:, action_idx[i] : action_idx[i + 1]],
)
for i in range(len(act_size))
],
axis=1,
)
),
axis=1,
keepdims=True,
)
return log_probs
def _create_entropy(
self,
all_log_probs: tf.Tensor,
sample_onehot: tf.Tensor,
action_idx: List[int],
act_size: List[int],
) -> tf.Tensor:
entropy = tf.reduce_sum(
(
tf.stack(
[
tf.nn.softmax_cross_entropy_with_logits_v2(
labels=tf.nn.softmax(
all_log_probs[:, action_idx[i] : action_idx[i + 1]]
),
logits=all_log_probs[:, action_idx[i] : action_idx[i + 1]],
)
for i in range(len(act_size))
],
axis=1,
)
),
axis=1,
)
return entropy
@property
def log_probs(self) -> tf.Tensor:
return self._all_probs
@property
def total_log_probs(self) -> tf.Tensor:
return self._total_prob
@property
def sample(self) -> tf.Tensor:
return self._sampled_policy
@property
def sample_onehot(self) -> tf.Tensor:
return self._sampled_onehot
@property
def entropy(self) -> tf.Tensor:
return self._entropy
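
As the constructor docstring notes, action_masks is a flat tensor of length sum(act_size) per batch entry, with 1 for allowed actions and 0 for masked ones. A small hypothetical layout, assuming act_size = [3, 2]:

# Branch 0 occupies the first three columns, branch 1 the last two.
# 1 = allowed, 0 = masked; here the last action of each branch is forbidden.
mask = [[1, 1, 0, 1, 0]]
# fed at run time via feed_dict={action_masks: mask}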

ml-agents/mlagents/trainers/tests/test_distributions.py (140 changes)


import pytest
from mlagents.tf_utils import tf
import yaml
from mlagents.trainers.common.distributions import (
GaussianDistribution,
MultiCategoricalDistribution,
)
@pytest.fixture
def dummy_config():
return yaml.safe_load(
"""
trainer: ppo
batch_size: 32
beta: 5.0e-3
buffer_size: 512
epsilon: 0.2
hidden_units: 128
lambd: 0.95
learning_rate: 3.0e-4
max_steps: 5.0e4
normalize: true
num_epoch: 5
num_layers: 2
time_horizon: 64
sequence_length: 64
summary_freq: 1000
use_recurrent: false
normalize: true
memory_size: 8
curiosity_strength: 0.0
curiosity_enc_size: 1
summary_path: test
model_path: test
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)
VECTOR_ACTION_SPACE = [2]
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 32
NUM_AGENTS = 12
def test_gaussian_distribution():
with tf.Graph().as_default():
logits = tf.Variable(initial_value=[[0, 0]], trainable=True, dtype=tf.float32)
distribution = GaussianDistribution(
logits,
act_size=VECTOR_ACTION_SPACE,
reparameterize=False,
tanh_squash=False,
)
sess = tf.Session()
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
output = sess.run(distribution.sample)
for _ in range(10):
output = sess.run([distribution.sample, distribution.log_probs])
for out in output:
assert out.shape[1] == VECTOR_ACTION_SPACE[0]
output = sess.run([distribution.total_log_probs])
assert output[0].shape[0] == 1
def test_tanh_distribution():
with tf.Graph().as_default():
logits = tf.Variable(initial_value=[[0, 0]], trainable=True, dtype=tf.float32)
distribution = GaussianDistribution(
logits, act_size=VECTOR_ACTION_SPACE, reparameterize=False, tanh_squash=True
)
sess = tf.Session()
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
output = sess.run(distribution.sample)
for _ in range(10):
output = sess.run([distribution.sample, distribution.log_probs])
for out in output:
assert out.shape[1] == VECTOR_ACTION_SPACE[0]
# Assert action never exceeds [-1,1]
action = output[0][0]
for act in action:
assert act >= -1 and act <= 1
output = sess.run([distribution.total_log_probs])
assert output[0].shape[0] == 1
def test_multicategorical_distribution():
with tf.Graph().as_default():
logits = tf.Variable(initial_value=[[0, 0]], trainable=True, dtype=tf.float32)
action_masks = tf.Variable(
initial_value=[[1 for _ in range(sum(DISCRETE_ACTION_SPACE))]],
trainable=True,
dtype=tf.float32,
)
distribution = MultiCategoricalDistribution(
logits, act_size=DISCRETE_ACTION_SPACE, action_masks=action_masks
)
sess = tf.Session()
with tf.Session() as sess:
init = tf.global_variables_initializer()
sess.run(init)
output = sess.run(distribution.sample)
for _ in range(10):
sample, log_probs = sess.run(
[distribution.sample, distribution.log_probs]
)
assert len(log_probs[0]) == sum(DISCRETE_ACTION_SPACE)
# Assert each sampled action is within its branch's valid range
assert len(sample[0]) == len(DISCRETE_ACTION_SPACE)
for i, act in enumerate(sample[0]):
assert act >= 0 and act <= DISCRETE_ACTION_SPACE[i]
output = sess.run([distribution.total_log_probs])
assert output[0].shape[0] == 1
# Test masks
mask = []
for space in DISCRETE_ACTION_SPACE:
mask.append(1)
for _action_space in range(1, space):
mask.append(0)
for _ in range(10):
sample, log_probs = sess.run(
[distribution.sample, distribution.log_probs],
feed_dict={action_masks: [mask]},
)
for act in sample[0]:
assert act >= 0 and act <= 1
output = sess.run([distribution.total_log_probs])