|
|
|
|
|
|
:param encoded_state: Tensor corresponding to encoded current state. |
|
|
|
:param encoded_next_state: Tensor corresponding to encoded next state. |
|
|
|
""" |
|
|
|
|
|
|
|
if separate_train: |
|
|
|
encoded_state = tf.stop_gradient(encoded_state) |
|
|
|
|
|
|
|
|
|
|
|
if separate_train: |
|
|
|
hidden = tf.stop_gradient(hidden) |
|
|
|
|
|
|
|
for i in range(forward_layers): |
|
|
|
hidden = tf.layers.dense( |
|
|
|
|
|
|
separate_train: bool = False |
|
|
|
): |
|
|
|
|
|
|
|
if separate_train: |
|
|
|
encoded_state = tf.stop_gradient(encoded_state) |
|
|
|
|
|
|
|
if separate_train: |
|
|
|
hidden = tf.stop_gradient(hidden) |
|
|
|
|
|
|
|
for i in range(forward_layers): |
|
|
|
hidden = tf.layers.dense( |
|
|
|
hidden, |
|
|
|