|
|
|
|
|
|
self.trainable_variables = tf.get_collection( |
|
|
|
tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy" |
|
|
|
) |
|
|
|
self.trainable_variables += tf.get_collection( |
|
|
|
tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm" |
|
|
|
) # LSTMs need to be root scope for Barracuda export |
|
|
|
|
|
|
|
self.inference_dict: Dict[str, tf.Tensor] = { |
|
|
|
"action": self.output, |
|
|
|
|
|
|
hidden_stream, |
|
|
|
self.memory_in, |
|
|
|
self.sequence_length_ph, |
|
|
|
name="policy/lstm", |
|
|
|
name="lstm_policy", |
|
|
|
) |
|
|
|
|
|
|
|
self.memory_out = tf.identity(memory_policy_out, name="recurrent_out") |
|
|
|
|
|
|
hidden_policy, |
|
|
|
self.memory_in, |
|
|
|
self.sequence_length_ph, |
|
|
|
name="policy/lstm", |
|
|
|
name="lstm_policy", |
|
|
|
) |
|
|
|
|
|
|
|
self.memory_out = tf.identity(memory_policy_out, "recurrent_out") |
|
|
|