
Move encoder creation to separate function

/develop/nopreviousactions
Ervin Teng 5 years ago
Commit 242e2421
1 changed file with 37 additions and 35 deletions
ml-agents/mlagents/trainers/common/nn_policy.py (72 changed lines)

@@ def create_tf_graph @@
                 return
             self.create_input_placeholders()
+            encoded = self._create_encoder(
+                self.visual_in,
+                self.processed_vector_in,
+                self.h_size,
+                self.num_layers,
+                self.vis_encode_type,
+            )
             if self.use_continuous_act:
                 self._create_cc_actor(
-                    self.h_size,
-                    self.num_layers,
-                    self.vis_encode_type,
+                    encoded,
                     tanh_squash,
                     reparameterize,
                     condition_sigma_on_obs,
                 )
             else:
-                self._create_dc_actor(
-                    self.h_size, self.num_layers, self.vis_encode_type
-                )
+                self._create_dc_actor(encoded)
             self.trainable_variables = tf.get_collection(
                 tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy"
             )
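Read together, the hunk above builds the observation encoder once in create_tf_graph and hands the resulting tensor to whichever actor head applies. A minimal, self-contained sketch of that pattern, assuming plain vector observations; build_encoder and the sizes are illustrative, not the repository's API:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

def build_encoder(vector_in, h_size, num_layers):
    # Shared stack of dense layers; built once instead of once per actor type.
    hidden = vector_in
    for i in range(num_layers):
        hidden = tf.layers.dense(
            hidden, h_size, activation=tf.nn.relu, name="hidden_%d" % i
        )
    return hidden

vector_in = tf.placeholder(tf.float32, [None, 8], name="vector_obs")
encoded = build_encoder(vector_in, h_size=128, num_layers=2)

# Both actor heads consume the same encoding, as in the diff above.
mu = tf.layers.dense(encoded, 2, name="mu")                 # continuous head
logits = tf.layers.dense(encoded, 3, name="action_branch")  # discrete head
```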

@@ def evaluate @@
         run_out = self._execute_model(feed_dict, self.inference_dict)
         return run_out

-    def _create_cc_actor(
+    def _create_encoder(
         self,
+        visual_in: List[tf.Tensor],
+        vector_in: tf.Tensor,
         h_size: int,
         num_layers: int,
         vis_encode_type: EncoderType,
-        tanh_squash: bool = False,
-        reparameterize: bool = False,
-        condition_sigma_on_obs: bool = True,
-    ) -> None:
+    ) -> tf.Tensor:
         """
-        Creates Continuous control actor-critic model.
+        Creates an encoder for visual and vector observations.
         :param h_size: Size of hidden linear layers.
         :param num_layers: Number of hidden linear layers.
         :param vis_encode_type: Type of visual encoder to use if visual input.
-        :param tanh_squash: Whether to use a tanh function, or a clipped output.
-        :param reparameterize: Whether we are using the resampling trick to update the policy.
+        :return: The hidden layer (tf.Tensor) after the encoder.
         """
-        hidden_stream = ModelUtils.create_observation_streams(
+        encoded = ModelUtils.create_observation_streams(
             self.visual_in,
             self.processed_vector_in,
             1,
             h_size,
             num_layers,
             vis_encode_type,
         )[0]
+        return encoded
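The new _create_encoder requests a single stream from ModelUtils.create_observation_streams and keeps the first element, hence the trailing [0]. The hedged stand-in below mimics the shape of that contract; observation_streams_sketch is a made-up name, not ModelUtils' implementation:

```python
from typing import List

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

def observation_streams_sketch(
    vector_in: tf.Tensor, num_streams: int, h_size: int, num_layers: int
) -> List[tf.Tensor]:
    # Stand-in for ModelUtils.create_observation_streams: one independent
    # encoder tower per requested stream, returned as a list.
    streams = []
    for s in range(num_streams):
        hidden = vector_in
        for i in range(num_layers):
            hidden = tf.layers.dense(
                hidden, h_size, activation=tf.nn.relu, name="stream%d_hidden%d" % (s, i)
            )
        streams.append(hidden)
    return streams

vector_in = tf.placeholder(tf.float32, [None, 8], name="vector_obs")
encoded = observation_streams_sketch(vector_in, 1, 128, 2)[0]  # one policy stream
```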
+    def _create_cc_actor(
+        self,
+        encoded: tf.Tensor,
+        tanh_squash: bool = False,
+        reparameterize: bool = False,
+        condition_sigma_on_obs: bool = True,
+    ) -> None:
+        """
+        Creates Continuous control actor-critic model.
+        :param h_size: Size of hidden linear layers.
+        :param num_layers: Number of hidden linear layers.
+        :param vis_encode_type: Type of visual encoder to use if visual input.
+        :param tanh_squash: Whether to use a tanh function, or a clipped output.
+        :param reparameterize: Whether we are using the resampling trick to update the policy.
+        """
@@ def _create_cc_actor @@
             hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
-                hidden_stream,
-                self.memory_in,
-                self.sequence_length_ph,
-                name="lstm_policy",
+                encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy"
             )
@@ def _create_cc_actor @@
         else:
-            hidden_policy = hidden_stream
+            hidden_policy = encoded

         with tf.variable_scope("policy"):
             mu = tf.layers.dense(
@@ def _create_cc_actor @@
             self.action_holder = tf.placeholder(
                 shape=[None, self.act_size[0]], dtype=tf.float32, name="action_holder"
             )
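Inside _create_cc_actor the shared encoding now either feeds the recurrent encoder or is used directly. A sketch of that branch, with an explicit BasicLSTMCell standing in for ModelUtils.create_recurrent_encoder; the memory layout (cell state and hidden state concatenated into m_size) and the fixed sequence length are assumptions for illustration:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

use_recurrent = True
m_size = 64           # assumed memory size: half cell state, half hidden state
sequence_length = 4   # assumed fixed unroll length for this sketch

encoded = tf.placeholder(tf.float32, [None, 128], name="encoded")
if use_recurrent:
    memory_in = tf.placeholder(tf.float32, [None, m_size], name="recurrent_in")
    # Fold the flat batch into [batch, time, features] before the LSTM.
    seq = tf.reshape(encoded, [-1, sequence_length, 128])
    c, h = tf.split(memory_in, 2, axis=1)
    cell = tf.nn.rnn_cell.BasicLSTMCell(m_size // 2)
    outputs, state_out = tf.nn.dynamic_rnn(
        cell, seq, initial_state=tf.nn.rnn_cell.LSTMStateTuple(c, h)
    )
    hidden_policy = tf.reshape(outputs, [-1, m_size // 2])
    memory_out = tf.concat([state_out.c, state_out.h], axis=1, name="recurrent_out")
else:
    hidden_policy = encoded  # feedforward: use the shared encoding as-is
```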
@@ def _create_dc_actor @@
-    def _create_dc_actor(
-        self, h_size: int, num_layers: int, vis_encode_type: EncoderType
-    ) -> None:
+    def _create_dc_actor(self, encoded: tf.Tensor) -> None:
         """
         Creates Discrete control actor-critic model.
         :param h_size: Size of hidden linear layers.

@@ def _create_dc_actor @@
-        with tf.variable_scope("policy"):
-            hidden_stream = ModelUtils.create_observation_streams(
-                self.visual_in,
-                self.processed_vector_in,
-                1,
-                h_size,
-                num_layers,
-                vis_encode_type,
-            )[0]
         if self.use_recurrent:
             self.prev_action = tf.placeholder(
                 shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
             )
             prev_action_oh = tf.concat(
                 [
                     tf.one_hot(self.prev_action[:, i], self.act_size[i])
                     for i in range(len(self.act_size))
                 ],
                 axis=1,
             )
-            hidden_policy = tf.concat([hidden_stream, prev_action_oh], axis=1)
+            hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)
             self.memory_in = tf.placeholder(
                 shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
             )
@@ def _create_dc_actor @@
             self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
         else:
-            hidden_policy = hidden_stream
+            hidden_policy = encoded
         policy_branches = []
         with tf.variable_scope("policy"):
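The discrete actor still conditions on the previous action when recurrent: each branch's last action is one-hot encoded and appended to the shared encoding, as in the tf.concat([encoded, prev_action_oh], axis=1) line above. A runnable sketch with assumed branch sizes:

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

act_size = [3, 2]  # assumed: two discrete branches with 3 and 2 actions each
encoded = tf.placeholder(tf.float32, [None, 128], name="encoded")
prev_action = tf.placeholder(tf.int32, [None, len(act_size)], name="prev_action")

# One-hot encode the previous action of every branch, then append the result
# to the shared encoding, mirroring the hunk above.
prev_action_oh = tf.concat(
    [tf.one_hot(prev_action[:, i], act_size[i]) for i in range(len(act_size))],
    axis=1,
)  # shape [batch, 3 + 2]
hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)  # shape [batch, 133]
```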
