|
|
|
|
|
|
self.prev_action = tf.placeholder( |
|
|
|
shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action" |
|
|
|
) |
|
|
|
prev_action_oh = tf.concat( |
|
|
|
[ |
|
|
|
tf.one_hot(self.prev_action[:, i], self.act_size[i]) |
|
|
|
for i in range(len(self.act_size)) |
|
|
|
], |
|
|
|
axis=1, |
|
|
|
) |
|
|
|
hidden_policy = tf.concat([encoded, prev_action_oh], axis=1) |
|
|
|
|
|
|
|
hidden_policy, |
|
|
|
self.memory_in, |
|
|
|
self.sequence_length_ph, |
|
|
|
name="lstm_policy", |
|
|
|
encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy" |
|
|
|
) |
|
|
|
|
|
|
|
self.memory_out = tf.identity(memory_policy_out, "recurrent_out") |
|
|
|