浏览代码

calculating gradient norms

/sensitivity
Andrew Cohen 4 年前
当前提交
1e50c76e
共有 1 个文件被更改,包括 15 次插入0 次删除
  1. 15
      ml-agents/mlagents/trainers/ppo/optimizer.py

15
ml-agents/mlagents/trainers/ppo/optimizer.py


def _create_ppo_optimizer_ops(self):
self.tf_optimizer = self.create_optimizer_op(self.learning_rate)
self.grads = self.tf_optimizer.compute_gradients(self.loss)
self.sensitivity = self.tf_optimizer.compute_gradients(
self.policy.output, var_list=self.policy.vector_in
)
self.update_batch = self.tf_optimizer.minimize(self.loss)
@timed

update_vals = self._execute_model(feed_dict, self.update_dict)
for stat_name, update_name in stats_needed.items():
update_stats[stat_name] = update_vals[update_name]
print(
len(
np.mean(
self._execute_model(feed_dict, {"sensi": self.sensitivity})[
"sensi"
][0][0],
axis=0,
)
)
)
return update_stats
def _construct_feed_dict(

正在加载...
取消
保存