|
|
|
|
|
|
:param num_layers: Number of hidden linear layers. |
|
|
|
""" |
|
|
|
hidden_streams = self.create_observation_streams( |
|
|
|
1, h_size, num_layers, vis_encode_type |
|
|
|
2, h_size, num_layers, vis_encode_type |
|
|
|
hidden = hidden_streams[0] |
|
|
|
|
|
|
|
if self.use_recurrent: |
|
|
|
self.prev_action = tf.placeholder( |
|
|
|
|
|
|
], |
|
|
|
axis=1, |
|
|
|
) |
|
|
|
hidden = tf.concat([hidden, prev_action_oh], axis=1) |
|
|
|
hidden_policy = tf.concat([hidden_streams[0], prev_action_oh], axis=1) |
|
|
|
hidden, memory_out = self.create_recurrent_encoder( |
|
|
|
hidden, self.memory_in, self.sequence_length |
|
|
|
_half_point = int(self.m_size / 2) |
|
|
|
hidden_policy, memory_policy_out = self.create_recurrent_encoder( |
|
|
|
hidden_policy, |
|
|
|
self.memory_in[:, :_half_point], |
|
|
|
self.sequence_length, |
|
|
|
name="lstm_policy", |
|
|
|
) |
|
|
|
|
|
|
|
hidden_value, memory_value_out = self.create_recurrent_encoder( |
|
|
|
hidden_streams[1], |
|
|
|
self.memory_in[:, _half_point:], |
|
|
|
self.sequence_length, |
|
|
|
name="lstm_value", |
|
|
|
) |
|
|
|
self.memory_out = tf.concat( |
|
|
|
[memory_policy_out, memory_value_out], axis=1, name="recurrent_out" |
|
|
|
self.memory_out = tf.identity(memory_out, name="recurrent_out") |
|
|
|
else: |
|
|
|
hidden_policy = hidden_streams[0] |
|
|
|
hidden_value = hidden_streams[1] |
|
|
|
hidden, |
|
|
|
hidden_policy, |
|
|
|
size, |
|
|
|
activation=None, |
|
|
|
use_bias=False, |
|
|
|
|
|
|
self.output = tf.identity(output) |
|
|
|
self.normalized_logits = tf.identity(normalized_logits, name="action") |
|
|
|
|
|
|
|
self.create_value_heads(self.stream_names, hidden) |
|
|
|
self.create_value_heads(self.stream_names, hidden_value) |
|
|
|
|
|
|
|
self.action_holder = tf.placeholder( |
|
|
|
shape=[None, len(policy_branches)], dtype=tf.int32, name="action_holder" |
|
|
|