|
|
|
|
|
|
reuse_encoder, |
|
|
|
) |
|
|
|
|
|
|
|
self.action_encoder = self._create_action_encoder( |
|
|
|
self.current_action, |
|
|
|
self.h_size, |
|
|
|
self.action_feature_size, |
|
|
|
action_layers, |
|
|
|
) |
|
|
|
self.action_encoder = self.current_action # self._create_action_encoder( |
|
|
|
# self.current_action, |
|
|
|
# self.h_size, |
|
|
|
# self.action_feature_size, |
|
|
|
# action_layers, |
|
|
|
# ) |
|
|
|
|
|
|
|
if not reuse_encoder: |
|
|
|
self.targ_encoder = tf.stop_gradient(self.targ_encoder) |
|
|
|
|
|
|
encoding_checkpoint = os.path.join(self.model_path, f"encoding.ckpt") |
|
|
|
encoding_saver.save(self.sess, encoding_checkpoint) |
|
|
|
|
|
|
|
action_vars = tf.get_collection( |
|
|
|
tf.GraphKeys.TRAINABLE_VARIABLES, "action_enc" |
|
|
|
) |
|
|
|
action_saver = tf.train.Saver(action_vars) |
|
|
|
action_checkpoint = os.path.join(self.model_path, f"action_enc.ckpt") |
|
|
|
action_saver.save(self.sess, action_checkpoint) |
|
|
|
# action_vars = tf.get_collection( |
|
|
|
# tf.GraphKeys.TRAINABLE_VARIABLES, "action_enc" |
|
|
|
# ) |
|
|
|
# action_saver = tf.train.Saver(action_vars) |
|
|
|
# action_checkpoint = os.path.join(self.model_path, f"action_enc.ckpt") |
|
|
|
# action_saver.save(self.sess, action_checkpoint) |
|
|
|
|
|
|
|
latent_vars = tf.get_collection( |
|
|
|
tf.GraphKeys.TRAINABLE_VARIABLES, "encoding/latent" |
|
|
|
|
|
|
:param encoded_state: Tensor corresponding to encoded current state. |
|
|
|
:param encoded_next_state: Tensor corresponding to encoded next state. |
|
|
|
""" |
|
|
|
if not self.transfer: |
|
|
|
encoded_state = tf.stop_gradient(encoded_state) |
|
|
|
|
|
|
|
if not self.transfer: |
|
|
|
hidden = tf.stop_gradient(hidden) |
|
|
|
|
|
|
|
for i in range(forward_layers): |
|
|
|
hidden = tf.layers.dense( |
|
|
|
hidden, |
|
|
|
|
|
|
forward_layers: int, |
|
|
|
separate_train: bool = False, |
|
|
|
): |
|
|
|
if not self.transfer: |
|
|
|
encoded_state = tf.stop_gradient(encoded_state) |
|
|
|
|
|
|
|
if not self.transfer: |
|
|
|
hidden = tf.stop_gradient(hidden) |
|
|
|
|
|
|
|
for i in range(forward_layers): |
|
|
|
hidden = tf.layers.dense( |
|
|
|