|
|
|
|
|
|
|
|
|
|
if self.use_bisim: |
|
|
|
if self.use_var_predict: |
|
|
|
predict_diff = self.policy.predict_distribution.w_distance( |
|
|
|
predict_diff = tf.reduce_mean(self.policy.predict_distribution.w_distance( |
|
|
|
) |
|
|
|
)) |
|
|
|
tf.squared_difference( |
|
|
|
self.policy.bisim_predict, self.policy.predict |
|
|
|
) |
|
|
|
tf.reduce_sum( |
|
|
|
tf.squared_difference( |
|
|
|
self.policy.bisim_predict, self.policy.predict |
|
|
|
), axis=1) |
|
|
|
tf.squared_difference( |
|
|
|
self.policy.bisim_pred_reward, self.policy.pred_reward |
|
|
|
tf.abs( |
|
|
|
self.policy.bisim_pred_reward - self.policy.pred_reward |
|
|
|
].gamma * predict_diff + tf.abs(reward_diff) |
|
|
|
encode_dist = tf.reduce_mean( |
|
|
|
tf.squared_difference(self.policy.encoder, self.policy.bisim_encoder) |
|
|
|
) |
|
|
|
].gamma * predict_diff + reward_diff |
|
|
|
encode_dist = tf.reduce_mean(tf.reduce_sum( |
|
|
|
# tf.squared_difference(self.policy.encoder, self.policy.bisim_encoder) |
|
|
|
tf.abs(self.policy.encoder - self.policy.bisim_encoder), axis=1 |
|
|
|
)) |
|
|
|
self.predict_difference = predict_diff |
|
|
|
self.reward_difference = reward_diff |
|
|
|
self.encode_difference = encode_dist |
|
|
|
self.bisim_loss = tf.squared_difference(encode_dist, predict_diff) |
|
|
|
|
|
|
|
self.loss = ( |
|
|
|
|
|
|
batch.shuffle(sequence_length=1) |
|
|
|
batch2 = copy.deepcopy(batch) |
|
|
|
bisim_stats = self.update_encoder(batch1, batch2) |
|
|
|
|
|
|
|
elif self.use_transfer and self.smart_transfer: |
|
|
|
if self.update_mode == "model": |
|
|
|
update_vals = self._execute_model(feed_dict, self.update_dict) |
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
update_vals = self._execute_model(feed_dict, self.bisim_update_dict) |
|
|
|
|
|
|
|
# print("predict:", self.policy.sess.run(self.predict_difference, feed_dict)) |
|
|
|
# print("reward:", self.policy.sess.run(self.reward_difference, feed_dict)) |
|
|
|
# print("encode:", self.policy.sess.run(self.encode_difference, feed_dict)) |
|
|
|
# print("bisim loss:", self.policy.sess.run(self.bisim_loss, feed_dict)) |
|
|
|
for stat_name, update_name in stats_needed.items(): |
|
|
|
if update_name in update_vals.keys(): |
|
|
|
update_stats[stat_name] = update_vals[update_name] |
|
|
|