|
|
|
|
|
|
|
|
|
|
if self.use_bisim: |
|
|
|
if self.use_var_predict: |
|
|
|
predict_diff = self.policy.predict_distribution.w_distance(self.policy.bisim_predict_distribution) |
|
|
|
self.predict_diff = self.policy.predict_distribution.w_distance(self.policy.bisim_predict_distribution) |
|
|
|
predict_diff = tf.reduce_mean( |
|
|
|
tf.squared_difference(self.policy.bisim_predict, self.policy.predict) |
|
|
|
self.predict_diff = tf.reduce_sum( |
|
|
|
tf.squared_difference(self.policy.bisim_predict, self.policy.predict), axis=1 |
|
|
|
reward_diff = tf.reduce_mean( |
|
|
|
tf.squared_difference(self.policy.bisim_pred_reward, self.policy.pred_reward) |
|
|
|
self.reward_diff = tf.reduce_sum( |
|
|
|
tf.abs(self.policy.bisim_pred_reward - self.policy.pred_reward), axis=1 |
|
|
|
) |
|
|
|
bisim_diff = 0.99 * self.predict_diff + self.reward_diff |
|
|
|
self.encode_dist = tf.reduce_sum( |
|
|
|
tf.abs(self.policy.encoder - self.policy.bisim_encoder), axis=1 |
|
|
|
predict_diff = 0.99 * predict_diff + tf.abs(reward_diff) |
|
|
|
encode_dist = tf.reduce_mean( |
|
|
|
tf.abs(self.policy.encoder - self.policy.bisim_encoder) |
|
|
|
# tf.squared_difference(self.policy.encoder, self.policy.bisim_encoder) |
|
|
|
) |
|
|
|
self.encode_dist_val = encode_dist |
|
|
|
self.predict_diff_val = predict_diff |
|
|
|
self.bisim_loss = tf.squared_difference(encode_dist, predict_diff) |
|
|
|
self.bisim_loss = tf.reduce_mean(tf.squared_difference(self.encode_dist, bisim_diff)) |
|
|
|
|
|
|
|
self.loss = ( |
|
|
|
self.policy_loss |
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
update_vals = self._execute_model(feed_dict, self.bisim_update_dict) |
|
|
|
# print("model difference:", self.policy.sess.run(self.predict_diff_val, feed_dict=feed_dict)) |
|
|
|
# print("encoder distance:", self.policy.sess.run(self.encode_dist_val, feed_dict=feed_dict)) |
|
|
|
# print("model difference:", self.policy.sess.run(self.predict_diff, feed_dict=feed_dict)) |
|
|
|
# print("reward difference:", self.policy.sess.run(self.reward_diff, feed_dict=feed_dict)) |
|
|
|
# print("encoder distance:", self.policy.sess.run(self.encode_dist, feed_dict=feed_dict)) |
|
|
|
# print("bisim loss:", self.policy.sess.run(self.bisim_loss, feed_dict=feed_dict)) |
|
|
|
|
|
|
|
for stat_name, update_name in stats_needed.items(): |
|
|
|
if update_name in update_vals.keys(): |
|
|
|