
Re-fix scoping and add method to get all variables

/develop/nopreviousactions
Ervin Teng, 5 years ago
Current commit
cdd57468
3 files changed, with 89 insertions and 69 deletions
  1. ml-agents/mlagents/trainers/common/nn_policy.py (143 changed lines)
  2. ml-agents/mlagents/trainers/sac/optimizer.py (4 changed lines)
  3. ml-agents/mlagents/trainers/tf_policy.py (11 changed lines)

ml-agents/mlagents/trainers/common/nn_policy.py (143 changed lines)


)
self.tanh_squash = tanh_squash
self.resample = resample
self.trainable_variables: List[tf.Variable] = []
def create_tf_graph(self):
def get_trainable_variables(self) -> List[tf.Variable]:
"""
Returns a List of the trainable variables in this policy. If create_tf_graph hasn't been called,
returns an empty list.
"""
return self.trainable_variables
def create_tf_graph(self) -> None:
with tf.variable_scope("policy"):
self.create_input_placeholders()
if self.use_continuous_act:
self.create_cc_actor(
self.h_size,
self.num_layers,
self.vis_encode_type,
self.tanh_squash,
self.resample,
)
else:
self.create_dc_actor(
self.h_size, self.num_layers, self.vis_encode_type
)
self.create_input_placeholders()
if self.use_continuous_act:
self.create_cc_actor(
self.h_size,
self.num_layers,
self.vis_encode_type,
self.tanh_squash,
self.resample,
)
else:
self.create_dc_actor(self.h_size, self.num_layers, self.vis_encode_type)
self.trainable_variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy"
)
self.inference_dict: Dict[str, tf.Tensor] = {
"action": self.output,

:param tanh_squash: Whether to use a tanh function, or a clipped output.
:param resample: Whether we are using the resampling trick to update the policy.
"""
hidden_stream = LearningModel.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,
h_size,
num_layers,
vis_encode_type,
)[0]
with tf.variable_scope("policy"):
hidden_stream = LearningModel.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,
h_size,
num_layers,
vis_encode_type,
)[0]
if self.use_recurrent:
self.memory_in = tf.placeholder(
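
The tanh_squash and resample flags documented at the top of this hunk are consumed further down in the same function, where the sampled action is squashed and the reparameterization trick is optionally applied. LearningModel.create_observation_streams is the ml-agents helper that encodes the visual and processed vector observations into a single hidden vector; the two copies of the call differ only in whether it runs under the "policy" variable scope. A simplified stand-in, assuming vector observations only and plain dense layers (not the real helper):

```python
# Simplified stand-in for an observation-encoding stream (vector obs only).
# This is not LearningModel.create_observation_streams, just an illustrative sketch.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def toy_observation_stream(vector_in: tf.Tensor, h_size: int, num_layers: int) -> tf.Tensor:
    hidden = vector_in
    for i in range(num_layers):
        hidden = tf.layers.dense(
            hidden, h_size, activation=tf.nn.swish, name="hidden_%d" % i
        )
    return hidden


vector_in = tf.placeholder(tf.float32, [None, 8], name="vector_observation")
with tf.variable_scope("policy"):
    hidden_stream = toy_observation_stream(vector_in, h_size=128, num_layers=2)
print(hidden_stream)  # shape (?, 128); all layer variables are prefixed with "policy/"
```
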

hidden_stream,
self.memory_in,
self.sequence_length_ph,
name="lstm_policy",
name="policy/lstm",
)
self.memory_out = tf.identity(memory_policy_out, name="recurrent_out")
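
The two name arguments differ only in their prefix. That prefix matters because the name is used as the variable scope for the recurrent encoder's LSTM weights, and the tf.get_collection(..., scope="policy") call in create_tf_graph matches variables by name prefix, so the "policy/lstm" form keeps the LSTM weights inside the collected set even when no enclosing "policy" scope is open at the call site. A hedged sketch with toy variables:

```python
# Hedged sketch: scope="policy" in tf.get_collection is a name-prefix match,
# so a scope literally named "policy/lstm" is also picked up. Toy variables only.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

with tf.variable_scope("policy"):
    tf.get_variable("w_head", shape=[4, 2])
with tf.variable_scope("policy/lstm"):  # explicit "policy/..." prefix
    tf.get_variable("w_recurrent", shape=[4, 4])
with tf.variable_scope("lstm_policy"):  # no "policy" prefix
    tf.get_variable("w_other", shape=[4, 4])

collected = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy")
print([v.name for v in collected])
# ['policy/w_head:0', 'policy/lstm/w_recurrent:0']; "lstm_policy/w_other" is excluded
```
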

mu = tf.layers.dense(
hidden_policy,
self.act_size[0],
activation=None,
name="mu",
kernel_initializer=LearningModel.scaled_init(0.01),
reuse=tf.AUTO_REUSE,
)
with tf.variable_scope("policy"):
mu = tf.layers.dense(
hidden_policy,
self.act_size[0],
activation=None,
name="mu",
kernel_initializer=LearningModel.scaled_init(0.01),
reuse=tf.AUTO_REUSE,
)
# Policy-dependent log_sigma_sq
log_sigma = tf.layers.dense(
hidden_policy,
self.act_size[0],
activation=None,
name="log_std",
kernel_initializer=LearningModel.scaled_init(0.01),
)
# Policy-dependent log_sigma_sq
log_sigma = tf.layers.dense(
hidden_policy,
self.act_size[0],
activation=None,
name="log_std",
kernel_initializer=LearningModel.scaled_init(0.01),
)
log_sigma = tf.clip_by_value(log_sigma, LOG_STD_MIN, LOG_STD_MAX)
log_sigma = tf.clip_by_value(log_sigma, LOG_STD_MIN, LOG_STD_MAX)
sigma = tf.exp(log_sigma)
sigma = tf.exp(log_sigma)
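
These layers form a diagonal-Gaussian head: one dense layer outputs the mean mu and another the state-dependent log standard deviation, which is clipped before exponentiating so sigma stays in a sane range. A self-contained sketch of that head; the clip bounds and the variance-scaling stand-in for LearningModel.scaled_init(0.01) are assumptions, not values taken from this diff:

```python
# Hedged sketch of the mu / log_std head. Clip bounds and the initializer
# are assumptions standing in for constants defined elsewhere in nn_policy.py.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

LOG_STD_MIN, LOG_STD_MAX = -20, 2  # assumed, typical SAC-style bounds
act_size = 2                       # illustrative action dimension

hidden_policy = tf.placeholder(tf.float32, [None, 64], name="hidden_policy")
with tf.variable_scope("policy"):
    mu = tf.layers.dense(
        hidden_policy,
        act_size,
        activation=None,
        name="mu",
        # Stand-in for LearningModel.scaled_init(0.01): a small-scale initializer.
        kernel_initializer=tf.variance_scaling_initializer(scale=0.01),
    )
    log_sigma = tf.layers.dense(
        hidden_policy,
        act_size,
        activation=None,
        name="log_std",
        kernel_initializer=tf.variance_scaling_initializer(scale=0.01),
    )
log_sigma = tf.clip_by_value(log_sigma, LOG_STD_MIN, LOG_STD_MAX)
sigma = tf.exp(log_sigma)  # always positive and bounded on both sides
```
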
epsilon = tf.random_normal(tf.shape(mu))
epsilon = tf.random_normal(tf.shape(mu))
sampled_policy = mu + sigma * epsilon
sampled_policy = mu + sigma * epsilon
# Stop gradient if we're not doing the resampling trick
if not resample:
sampled_policy_probs = tf.stop_gradient(sampled_policy)
else:
sampled_policy_probs = sampled_policy
# Stop gradient if we're not doing the resampling trick
if not resample:
sampled_policy_probs = tf.stop_gradient(sampled_policy)
else:
sampled_policy_probs = sampled_policy
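
sampled_policy = mu + sigma * epsilon is the reparameterization trick: the randomness lives in epsilon ~ N(0, I), so gradients can flow through the sample into mu and sigma. When resample is False, tf.stop_gradient cuts that path for the copy used in the probability computation. A toy demonstration of the difference (not the policy graph):

```python
# Toy demonstration: gradients flow through a reparameterized sample,
# but not through its tf.stop_gradient copy.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

mu = tf.Variable([0.0, 0.0])
sigma = tf.Variable([1.0, 1.0])
epsilon = tf.random_normal(tf.shape(mu))

sampled = mu + sigma * epsilon               # resample=True path
sampled_stopped = tf.stop_gradient(sampled)  # resample=False path

loss_resampled = tf.reduce_sum(sampled ** 2)
loss_stopped = tf.reduce_sum(sampled_stopped ** 2)

grads_resampled = tf.gradients(loss_resampled, [mu, sigma])
grads_stopped = tf.gradients(loss_stopped, [mu, sigma])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grads_resampled))  # finite gradients w.r.t. mu and sigma
    print(grads_stopped)              # [None, None]: no gradient path
```
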
# Compute probability of model output.
_gauss_pre = -0.5 * (
((sampled_policy_probs - mu) / (sigma + EPSILON)) ** 2
+ 2 * log_sigma
+ np.log(2 * np.pi)
)
all_probs = _gauss_pre
all_probs = tf.reduce_sum(_gauss_pre, axis=1, keepdims=True)
# Compute probability of model output.
_gauss_pre = -0.5 * (
((sampled_policy_probs - mu) / (sigma + EPSILON)) ** 2
+ 2 * log_sigma
+ np.log(2 * np.pi)
)
all_probs = _gauss_pre
all_probs = tf.reduce_sum(_gauss_pre, axis=1, keepdims=True)
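
_gauss_pre is the per-dimension log-density of a diagonal Gaussian, -1/2 * (((x - mu)/sigma)^2 + 2*log(sigma) + log(2*pi)), and the reduce_sum over axis 1 turns it into the joint log-probability of the sampled action (the EPSILON added to sigma is only numerical padding). A quick NumPy check against scipy with illustrative numbers:

```python
# Numerical check: the _gauss_pre expression, summed over the action axis,
# equals the log-density of a diagonal Gaussian. Numbers are illustrative.
import numpy as np
from scipy.stats import norm

mu = np.array([[0.1, -0.3]])
sigma = np.array([[0.5, 1.2]])
x = np.array([[0.4, 0.2]])
log_sigma = np.log(sigma)

_gauss_pre = -0.5 * (((x - mu) / sigma) ** 2 + 2 * log_sigma + np.log(2 * np.pi))
all_probs = np.sum(_gauss_pre, axis=1, keepdims=True)

reference = norm.logpdf(x, loc=mu, scale=sigma).sum(axis=1, keepdims=True)
assert np.allclose(all_probs, reference)  # identical up to floating-point error
```
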
if tanh_squash:
self.output_pre = tf.tanh(sampled_policy)
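
With tanh_squash enabled, the raw Gaussian sample is passed through tanh so output_pre lies in (-1, 1); the rest of that branch is truncated in this hunk. For context, a tanh-squashed Gaussian normally also needs a change-of-variables correction when its log-probability is used: log p(a) = log p(u) - sum(log(1 - tanh(u)^2)) for a = tanh(u). A hedged numerical illustration (standard SAC-style math, not code from this diff):

```python
# Hedged illustration of the tanh change-of-variables correction for a
# squashed Gaussian sample. Values are illustrative.
import numpy as np
from scipy.stats import norm

mu, sigma = 0.2, 0.7
u = np.array([0.5])  # unsquashed Gaussian sample
a = np.tanh(u)       # squashed action in (-1, 1)

log_prob_u = norm.logpdf(u, loc=mu, scale=sigma)
# Subtract the log|det Jacobian| of the tanh transform.
log_prob_a = log_prob_u - np.log(1.0 - np.tanh(u) ** 2)
print(a, log_prob_u, log_prob_a)
```
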

:param num_layers: Number of hidden linear layers.
:param vis_encode_type: Type of visual encoder to use if visual input.
"""
hidden_stream = LearningModel.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,
h_size,
num_layers,
vis_encode_type,
)[0]
with tf.variable_scope("policy"):
hidden_stream = LearningModel.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,
h_size,
num_layers,
vis_encode_type,
)[0]
if self.use_recurrent:
self.prev_action = tf.placeholder(

hidden_policy,
self.memory_in,
self.sequence_length_ph,
name="lstm_policy",
name="policy/lstm",
)
self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
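
create_dc_actor mirrors the continuous case: the same observation stream, an optional recurrent encoder whose renaming parallels the one in create_cc_actor, and then per-branch logits for the discrete actions. The branch heads themselves are not visible in these hunks; the sketch below is only an assumption about their general shape, not the ml-agents implementation:

```python
# Hedged sketch of a branched discrete-action head; act_size and layer names
# are illustrative, and the real create_dc_actor body is not shown in this diff.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

act_size = [3, 2]  # assumed: two discrete branches with 3 and 2 actions
hidden_policy = tf.placeholder(tf.float32, [None, 64], name="hidden_policy")
with tf.variable_scope("policy"):
    policy_branches = [
        tf.layers.dense(hidden_policy, size, activation=None, name="branch_%d" % i)
        for i, size in enumerate(act_size)
    ]
# One logit vector per branch, concatenated along the action axis.
all_logits = tf.concat(policy_branches, axis=1)
print(all_logits)  # shape (?, 5)
```
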

ml-agents/mlagents/trainers/sac/optimizer.py (4 changed lines)


LOGGER.debug("q_vars")
self.print_all_vars(self.policy_network.q_vars)
LOGGER.debug("policy_vars")
policy_vars = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy"
)
policy_vars = self.policy.get_trainable_variables()
self.print_all_vars(policy_vars)
self.target_init_op = [
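
On the SAC side, the optimizer stops querying the graph collection itself and instead asks the policy for its variables. A hedged sketch of how an update op could be restricted to that list; the loss and the plain Adam step are stand-ins, not the real SAC policy update:

```python
# Hedged sketch: build an update op only over the variables the policy reports.
# `policy` is assumed to expose get_trainable_variables(), as added in this commit.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def build_policy_update(policy, loss: tf.Tensor, learning_rate: float = 3e-4):
    policy_vars = policy.get_trainable_variables()
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    # Restrict the update to the policy's own variables.
    return optimizer.minimize(loss, var_list=policy_vars)
```
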

ml-agents/mlagents/trainers/tf_policy.py (11 changed lines)


self.load = load
@abc.abstractmethod
def get_trainable_variables(self) -> List[tf.Variable]:
"""
Returns a List of the trainable variables in this policy. If create_tf_graph hasn't been called,
returns an empty list.
"""
pass
@abc.abstractmethod
def create_tf_graph(self):
"""
Builds the tensorflow graph needed for this policy.
"""
pass
def _initialize_graph(self):
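
tf_policy.py now declares both methods as abstract, so every concrete policy (such as the NNPolicy changes above) has to implement them. A minimal sketch of that contract, with TFPolicyLike as a hypothetical stand-in for the real TFPolicy base class:

```python
# Minimal sketch of the abstract contract; "TFPolicyLike" is a hypothetical
# stand-in, not the real TFPolicy class.
import abc
from typing import List

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


class TFPolicyLike(abc.ABC):
    @abc.abstractmethod
    def get_trainable_variables(self) -> List[tf.Variable]:
        """Return the policy's trainable variables ([] before the graph exists)."""

    @abc.abstractmethod
    def create_tf_graph(self) -> None:
        """Build the tensorflow graph needed for this policy."""


# A subclass that forgets either method cannot be instantiated.
class Incomplete(TFPolicyLike):
    def create_tf_graph(self) -> None:
        pass


try:
    Incomplete()
except TypeError as error:
    print(error)  # Can't instantiate abstract class Incomplete ...
```
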
