浏览代码

Split value and policy networks

/develop/separatevalue
Ervin Teng 5 年前
当前提交
35d73d1d
共有 1 个文件被更改,包括 23 次插入8 次删除
  1. 31
      ml-agents/mlagents/trainers/ppo/models.py

31
ml-agents/mlagents/trainers/ppo/models.py


:param num_layers: Number of hidden linear layers.
"""
hidden_streams = self.create_observation_streams(
1, h_size, num_layers, vis_encode_type
2, h_size, num_layers, vis_encode_type
hidden = hidden_streams[0]
if self.use_recurrent:
self.prev_action = tf.placeholder(

],
axis=1,
)
hidden = tf.concat([hidden, prev_action_oh], axis=1)
hidden_policy = tf.concat([hidden_streams[0], prev_action_oh], axis=1)
hidden, memory_out = self.create_recurrent_encoder(
hidden, self.memory_in, self.sequence_length
_half_point = int(self.m_size / 2)
hidden_policy, memory_policy_out = self.create_recurrent_encoder(
hidden_policy,
self.memory_in[:, :_half_point],
self.sequence_length,
name="lstm_policy",
)
hidden_value, memory_value_out = self.create_recurrent_encoder(
hidden_streams[1],
self.memory_in[:, _half_point:],
self.sequence_length,
name="lstm_value",
)
self.memory_out = tf.concat(
[memory_policy_out, memory_value_out], axis=1, name="recurrent_out"
self.memory_out = tf.identity(memory_out, name="recurrent_out")
else:
hidden_policy = hidden_streams[0]
hidden_value = hidden_streams[1]
hidden,
hidden_policy,
size,
activation=None,
use_bias=False,

self.output = tf.identity(output)
self.normalized_logits = tf.identity(normalized_logits, name="action")
self.create_value_heads(self.stream_names, hidden)
self.create_value_heads(self.stream_names, hidden_value)
self.action_holder = tf.placeholder(
shape=[None, len(policy_branches)], dtype=tf.int32, name="action_holder"

正在加载...
取消
保存