|
|
|
|
|
|
self.use_inverse_model = hyperparameters.use_inverse_model |
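# use_inverse_model: additionally train an inverse dynamics model (its loss is added to model_loss below).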
|
|
|
self.predict_return = hyperparameters.predict_return |
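# predict_return: have the dynamics model also predict the reward (see "Losses/Reward Loss" and the reward_diff term below).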
|
|
|
self.reuse_encoder = hyperparameters.reuse_encoder |
|
|
|
self.use_bisim = hyperparameters.use_bisim |
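# use_bisim: add a bisimulation objective that trains the encoder with its own optimizer (see bisim_loss below).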
|
|
|
|
|
|
|
self.use_alter = hyperparameters.use_alter |
|
|
|
self.in_batch_alter = hyperparameters.in_batch_alter |
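# use_alter / in_batch_alter: alternate between policy and model updates, either across update calls or within a single batch (see _init_alter_update).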
|
|
|
|
|
|
|
|
|
|
self.ppo_update_dict: Dict[str, tf.Tensor] = {} |
|
|
|
self.model_update_dict: Dict[str, tf.Tensor] = {} |
|
|
|
self.bisim_update_dict: Dict[str, tf.Tensor] = {} |
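# Fetch dictionaries for the three update passes: the PPO policy update, the dynamics-model update, and the bisimulation (encoder) update.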
|
|
|
self.predict_return, self.use_inverse_model, self.reuse_encoder) |
|
|
|
self.predict_return, self.use_inverse_model, self.reuse_encoder, self.use_bisim) |
|
|
|
|
|
|
|
with policy.graph.as_default(): |
|
|
|
super().__init__(policy, trainer_params) |
|
|
|
|
|
|
self.stats_name_to_update_name.update({ |
|
|
|
"Losses/Reward Loss": "reward_loss", |
|
|
|
}) |
|
|
|
# if self.use_bisim: |
|
|
|
# self.stats_name_to_update_name.update({ |
|
|
|
# "Losses/Bisim Loss": "bisim_loss", |
|
|
|
# }) |
|
|
|
if self.policy.use_recurrent: |
|
|
|
self.m_size = self.policy.m_size |
|
|
|
self.memory_in = tf.placeholder(
    shape=[None, self.m_size], dtype=tf.float32, name="recurrent_value_in"
)
|
|
|
|
|
|
self.learning_rate = ModelUtils.create_schedule(
    ScheduleType.LINEAR,
    lr,
    self.policy.global_step,
    int(max_step),
    min_value=1e-10,
)
|
|
|
self.bisim_learning_rate = ModelUtils.create_schedule( |
|
|
|
ScheduleType.LINEAR, |
|
|
|
lr, |
|
|
|
self.policy.global_step, |
|
|
|
int(max_step), |
|
|
|
min_value=1e-10, |
|
|
|
) |
|
|
|
self._create_losses( |
|
|
|
self.policy.total_log_probs, |
|
|
|
self.old_log_probs, |
|
|
|
|
|
|
) |
|
|
|
self.returns_holders[name] = returns_holder |
|
|
|
self.old_values[name] = old_value |
|
|
|
|
|
|
|
self.advantage = tf.placeholder( |
|
|
|
shape=[None], dtype=tf.float32, name="advantages" |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
if self.use_inverse_model: |
|
|
|
self.model_loss += 0.5 * self.policy.inverse_loss |
|
|
|
# self.model_loss = 0.2 * self.policy.forward_loss + 0.8 * self.policy.inverse_loss |
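# Bisimulation loss: squared difference between (a) the distance of the two
# encoded observation batches and (b) a target distance built from the
# predicted next-state difference (scaled by gamma) plus the predicted
# reward difference.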
|
|
|
|
|
|
|
if self.use_bisim: |
|
|
|
if self.use_var_predict: |
|
|
|
predict_diff = self.policy.predict_distribution.w_distance(self.policy.bisim_predict_distribution) |
|
|
|
else: |
|
|
|
predict_diff = tf.reduce_mean( |
|
|
|
tf.squared_difference(self.policy.bisim_predict, self.policy.predict) |
|
|
|
) |
|
|
|
if self.predict_return: |
|
|
|
reward_diff = tf.reduce_mean( |
|
|
|
tf.squared_difference(self.policy.bisim_pred_reward, self.policy.pred_reward) |
|
|
|
) |
|
|
|
predict_diff = self.reward_signals["extrinsic_value"].gamma * predict_diff + reward_diff |
|
|
|
encode_dist = tf.reduce_mean( |
|
|
|
tf.squared_difference(self.policy.encoder, self.policy.bisim_encoder) |
|
|
|
) |
|
|
|
self.bisim_loss = tf.squared_difference(encode_dist, predict_diff) |
|
|
|
|
|
|
|
self.loss = ( |
|
|
|
self.policy_loss |
|
|
|
+ self.model_loss |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
def _create_ppo_optimizer_ops(self): |
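    """Create the training ops for the combined PPO + model loss.

    Only the variable collections enabled by the train_* flags (e.g. the
    "encoding" scope when train_encoder is set) are handed to the optimizer.
    When use_bisim is enabled, a separate optimizer and update op are built
    for the bisimulation loss over the encoder variables.
    """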
|
|
|
# if self.use_transfer: |
|
|
|
# if self.transfer_type == "dynamics": |
|
|
|
# if self.train_type == "all": |
|
|
|
# train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) |
|
|
|
# elif self.train_type == "encoding": |
|
|
|
# train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "encoding") |
|
|
|
# # train_vars += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "value") |
|
|
|
# elif self.train_type == "policy": |
|
|
|
# train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "encoding") |
|
|
|
# train_vars += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "policy") |
|
|
|
# train_vars += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "value") |
|
|
|
# print("trainable", train_vars) |
|
|
|
# # train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "encoding") |
|
|
|
# # train_vars += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "policy") |
|
|
|
# # train_vars += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "value") |
|
|
|
# # train_vars += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "policy/mu") |
|
|
|
# # train_vars += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "policy/log_std") |
|
|
|
# # train_vars += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "value/extrinsic_value") |
|
|
|
# elif self.transfer_type == "observation": |
|
|
|
# if self.train_type == "all": |
|
|
|
# train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) |
|
|
|
# elif self.train_type == "policy": |
|
|
|
# train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "policy") \ |
|
|
|
# + tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "predict") \ |
|
|
|
# + tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "inverse") \ |
|
|
|
# + tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "value") |
|
|
|
# # + tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "encoding/latent") |
|
|
|
# else: |
|
|
|
# train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) |
|
|
|
train_vars = [] |
|
|
|
if self.train_encoder: |
|
|
|
train_vars += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "encoding") |
|
|
|
|
|
|
self.tf_optimizer = self.create_optimizer_op(self.learning_rate)
self.grads = self.tf_optimizer.compute_gradients(self.loss, var_list=train_vars)
|
|
|
self.update_batch = self.tf_optimizer.minimize(self.loss, var_list=train_vars) |
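# The bisimulation objective below gets its own learning-rate schedule, optimizer
# and update op, applied only to the encoder ("encoding") variables.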
|
|
|
|
|
|
|
if self.use_bisim: |
|
|
|
bisim_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "encoding") |
|
|
|
self.bisim_optimizer = self.create_optimizer_op(self.bisim_learning_rate) |
|
|
|
self.bisim_grads = self.bisim_optimizer.compute_gradients(self.bisim_loss, var_list=bisim_train_vars)
|
|
|
self.bisim_update_batch = self.bisim_optimizer.minimize(self.bisim_loss, var_list=bisim_train_vars)
|
|
|
self.bisim_update_dict.update( |
|
|
|
{ |
|
|
|
"bisim_loss": self.bisim_loss, |
|
|
|
"update_batch": self.bisim_update_batch, |
|
|
|
"bisim_learning_rate": self.bisim_learning_rate, |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
def _init_alter_update(self): |
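    """Set up the update dictionaries used when alternating between policy and model updates (use_alter / in_batch_alter)."""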
|
|
|
|
|
|
|
|
|
|
if update_name in update_vals.keys(): |
|
|
|
update_stats[stat_name] = update_vals[update_name] |
|
|
|
|
|
|
|
self.num_updates += 1 |
|
|
|
return update_stats |
|
|
|
|
|
|
|
def update_encoder(self, mini_batch1: AgentBuffer, mini_batch2: AgentBuffer, mini_batch3: AgentBuffer): |
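    """Run one update of the encoder on the bisimulation loss.

    mini_batch1 and mini_batch2 supply the two observation batches whose
    encodings are compared; mini_batch3 supplies the actions fed to the
    forward model. Returns the collected bisimulation stats.
    """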
|
|
|
|
|
|
|
stats_needed = { |
|
|
|
"Losses/Bisim Loss": "bisim_loss", |
|
|
|
"Policy/Bisim Learning Rate": "bisim learning_rate", |
|
|
|
} |
|
|
|
update_stats = {} |
|
|
|
|
|
|
|
feed_dict = { |
|
|
|
self.policy.vector_in: mini_batch1["vector_obs"],
self.policy.vector_bisim: mini_batch2["vector_obs"],
|
|
|
self.policy.current_action: mini_batch3["actions"], |
|
|
|
} |
|
|
|
|
|
|
|
update_vals = self._execute_model(feed_dict, self.bisim_update_dict) |
|
|
|
|
|
|
|
for stat_name, update_name in stats_needed.items(): |
|
|
|
if update_name in update_vals.keys(): |
|
|
|
update_stats[stat_name] = update_vals[update_name] |
|
|
|
|
|
|
|
return update_stats |
|
|
|
|
|
|
|
def _construct_feed_dict( |
|
|
|
|
|
|
self.policy.mask_input: mini_batch["masks"] * burn_in_mask, |
|
|
|
self.advantage: mini_batch["advantages"], |
|
|
|
self.all_old_log_probs: mini_batch["action_probs"], |
|
|
|
self.policy.processed_vector_next: mini_batch["next_vector_in"], |
|
|
|
# self.policy.next_vector_in: mini_batch["next_vector_in"], |
|
|
|
self.policy.vector_next: mini_batch["next_vector_in"], |
|
|
|
self.policy.current_action: mini_batch["actions"], |
|
|
|
self.policy.current_reward: mini_batch["extrinsic_rewards"], |
|
|
|
# self.dis_returns: mini_batch["discounted_returns"] |
|
|
|
|
|
|
) |
|
|
|
# print(self.policy.sess.run(self.policy.encoder, feed_dict={self.policy.vector_in: mini_batch["vector_obs"]})) |
|
|
|
return feed_dict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_cc_critic_old( |
|
|
|
self, h_size: int, num_layers: int, vis_encode_type: EncoderType |
|
|
|