
try to fix the bisim metric

/develop/bisim-review
yanchaosun, 4 years ago
Commit: f8b91faa
2 files changed, 25 insertions(+), 15 deletions(-)
  1. ml-agents/mlagents/trainers/policy/transfer_policy.py (7 changes)
  2. ml-agents/mlagents/trainers/ppo_transfer/optimizer.py (33 changes)

ml-agents/mlagents/trainers/policy/transfer_policy.py (7 changes)


         return kl

     def w_distance(self, another):
-        return tf.reduce_sum(tf.squared_difference(self.mu, another.mu) + tf.squared_difference(
-            self.sigma, another.sigma
-        ))
+        return tf.sqrt(
+            tf.reduce_sum(tf.squared_difference(self.mu, another.mu), axis=1) \
+            + tf.reduce_sum(tf.squared_difference(self.sigma, another.sigma), axis=1)
+        )

 class TransferPolicy(TFPolicy):
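Note: the reworked w_distance matches the closed-form 2-Wasserstein distance between diagonal Gaussians, computed per batch row instead of as a single unrooted sum over the whole batch. A minimal NumPy sketch of that closed form (not part of the commit; it assumes self.sigma stores per-dimension standard deviations rather than variances):

import numpy as np

# W2 between N(mu1, diag(sigma1^2)) and N(mu2, diag(sigma2^2)), one value
# per batch row; for diagonal Gaussians the optimal coupling gives
# W2^2 = ||mu1 - mu2||^2 + ||sigma1 - sigma2||^2.
def w2_diagonal_gaussians(mu1, sigma1, mu2, sigma2):
    return np.sqrt(
        np.sum((mu1 - mu2) ** 2, axis=1) + np.sum((sigma1 - sigma2) ** 2, axis=1)
    )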

ml-agents/mlagents/trainers/ppo_transfer/optimizer.py (33 changes)


         if self.use_bisim:
             if self.use_var_predict:
-                predict_diff = self.policy.predict_distribution.w_distance(
+                predict_diff = tf.reduce_mean(self.policy.predict_distribution.w_distance(
-                )
+                ))
-                tf.squared_difference(
-                    self.policy.bisim_predict, self.policy.predict
-                )
+                tf.reduce_sum(
+                    tf.squared_difference(
+                        self.policy.bisim_predict, self.policy.predict
+                    ), axis=1)
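Note: the common thread in these hunks is reducing over the feature axis before averaging. Without an axis argument, reduce_sum collapses the whole (batch, features) tensor into a single scalar; with axis=1 it yields one distance per transition, which reduce_mean then averages over the batch. A quick NumPy illustration with made-up values:

import numpy as np

diff = np.array([[1.0, 2.0], [3.0, 4.0]])  # (batch=2, features=2)
print(np.sum(diff ** 2))          # 30.0  -> one scalar for the entire batch
print(np.sum(diff ** 2, axis=1))  # [ 5. 25.] -> one value per sample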
-                tf.squared_difference(
-                    self.policy.bisim_pred_reward, self.policy.pred_reward
+                tf.abs(
+                    self.policy.bisim_pred_reward - self.policy.pred_reward
-                ].gamma * predict_diff + tf.abs(reward_diff)
-                encode_dist = tf.reduce_mean(
-                    tf.squared_difference(self.policy.encoder, self.policy.bisim_encoder)
-                )
+                ].gamma * predict_diff + reward_diff
+                encode_dist = tf.reduce_mean(tf.reduce_sum(
+                    # tf.squared_difference(self.policy.encoder, self.policy.bisim_encoder)
+                    tf.abs(self.policy.encoder - self.policy.bisim_encoder), axis=1
+                ))
                 self.predict_difference = predict_diff
                 self.reward_difference = reward_diff
                 self.encode_difference = encode_dist
                 self.bisim_loss = tf.squared_difference(encode_dist, predict_diff)
                 self.loss = (
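Note: taken together, these hunks regress the encoder distance onto a bisimulation-metric target, i.e. the absolute reward gap plus a discounted distance between predicted next-state distributions. A per-sample NumPy sketch of that loss, with hypothetical names (phi_i, phi_j, r_i, r_j, w2_next, gamma) standing in for the tensors the commit wires up:

import numpy as np

def bisim_loss(phi_i, phi_j, r_i, r_j, w2_next, gamma=0.99):
    """(||phi_i - phi_j||_1 - (gamma * W2(P_i, P_j) + |r_i - r_j|))^2.

    phi_*  : encoded observations, shape (batch, d)
    r_*    : predicted rewards, shape (batch,)
    w2_next: distance between predicted next-state distributions, shape (batch,)
    """
    encode_dist = np.sum(np.abs(phi_i - phi_j), axis=1)  # L1 distance per row
    target = gamma * w2_next + np.abs(r_i - r_j)
    return np.mean((encode_dist - target) ** 2)

The commit averages encode_dist and the target over the batch before squaring; the sketch squares per sample and then averages, which is the more common formulation.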

             batch.shuffle(sequence_length=1)
             batch2 = copy.deepcopy(batch)
             bisim_stats = self.update_encoder(batch1, batch2)
         elif self.use_transfer and self.smart_transfer:
             if self.update_mode == "model":
                 update_vals = self._execute_model(feed_dict, self.update_dict)

             }
             update_vals = self._execute_model(feed_dict, self.bisim_update_dict)
             # print("predict:", self.policy.sess.run(self.predict_difference, feed_dict))
             # print("reward:", self.policy.sess.run(self.reward_difference, feed_dict))
             # print("encode:", self.policy.sess.run(self.encode_difference, feed_dict))
             # print("bisim loss:", self.policy.sess.run(self.bisim_loss, feed_dict))
             for stat_name, update_name in stats_needed.items():
                 if update_name in update_vals.keys():
                     update_stats[stat_name] = update_vals[update_name]
