|
|
|
|
|
|
# reuse_encoder, |
|
|
|
# ) |
|
|
|
|
|
|
|
# self.action_encoder = self._create_action_encoder( |
|
|
|
# self.current_action, |
|
|
|
# self.h_size, |
|
|
|
# self.action_feature_size, |
|
|
|
# action_layers |
|
|
|
# ) |
|
|
|
self.action_encoder = self._create_action_encoder( |
|
|
|
self.current_action, |
|
|
|
self.h_size, |
|
|
|
self.action_feature_size, |
|
|
|
action_layers |
|
|
|
) |
|
|
|
|
|
|
|
if self.inverse_model: |
|
|
|
with tf.variable_scope("inverse"): |
|
|
|
|
|
|
|
|
|
|
self.predict, self.predict_distribution = self.create_forward_model( |
|
|
|
self.encoder, |
|
|
|
self.current_action, |
|
|
|
self.action_encoder, |
|
|
|
forward_layers, |
|
|
|
var_predict=var_predict, |
|
|
|
separate_train=separate_model_train |
|
|
|
|
|
|
self.targ_encoder, |
|
|
|
self.current_action, |
|
|
|
self.action_encoder, |
|
|
|
forward_layers, |
|
|
|
var_predict=var_predict, |
|
|
|
reuse=True, |
|
|
|
|
|
|
if predict_return: |
|
|
|
with tf.variable_scope("reward"): |
|
|
|
self.create_reward_model( |
|
|
|
self.encoder, self.current_action, forward_layers, separate_train=separate_model_train |
|
|
|
self.encoder, self.action_encoder, forward_layers, separate_train=separate_model_train |
|
|
|
) |
|
|
|
|
|
|
|
if self.use_bisim: |
|
|
|
|
|
|
num_layers: int, |
|
|
|
reuse: bool=False |
|
|
|
) -> tf.Tensor: |
|
|
|
|
|
|
|
if num_layers < 0: |
|
|
|
return self.current_action |
|
|
|
|
|
|
|
hidden_stream = ModelUtils.create_vector_observation_encoder( |
|
|
|
action, |
|
|
|
|
|
|
encoding_checkpoint = os.path.join(self.model_path, f"encoding.ckpt") |
|
|
|
encoding_saver.save(self.sess, encoding_checkpoint) |
|
|
|
|
|
|
|
# action_vars = tf.get_collection( |
|
|
|
# tf.GraphKeys.TRAINABLE_VARIABLES, "action_enc" |
|
|
|
# ) |
|
|
|
# action_saver = tf.train.Saver(action_vars) |
|
|
|
# action_checkpoint = os.path.join(self.model_path, f"action_enc.ckpt") |
|
|
|
# action_saver.save(self.sess, action_checkpoint) |
|
|
|
action_vars = tf.get_collection( |
|
|
|
tf.GraphKeys.TRAINABLE_VARIABLES, "action_enc" |
|
|
|
) |
|
|
|
if len(action_vars) > 0: |
|
|
|
action_saver = tf.train.Saver(action_vars) |
|
|
|
action_checkpoint = os.path.join(self.model_path, f"action_enc.ckpt") |
|
|
|
action_saver.save(self.sess, action_checkpoint) |
|
|
|
|
|
|
|
latent_vars = tf.get_collection( |
|
|
|
tf.GraphKeys.TRAINABLE_VARIABLES, "encoding/latent" |
|
|
|