浏览代码

sensitivity

/sensitivity
Andrew Cohen 5 年前
当前提交
0e965a4d
共有 2 个文件被更改,包括 11 次插入14 次删除
  1. 22
      ml-agents/mlagents/trainers/ppo/optimizer.py
  2. 3
      ml-agents/mlagents/trainers/ppo/trainer.py

22
ml-agents/mlagents/trainers/ppo/optimizer.py


def _create_ppo_optimizer_ops(self):
self.tf_optimizer = self.create_optimizer_op(self.learning_rate)
self.grads = self.tf_optimizer.compute_gradients(self.loss)
self.sensitivity = self.tf_optimizer.compute_gradients(
self.policy.output, var_list=self.policy.vector_in
self.sensitivity = tf.reduce_mean(
tf.square(tf.gradients(self.policy.output, self.policy.vector_in)), axis=1
)
self.update_batch = self.tf_optimizer.minimize(self.loss)

update_vals = self._execute_model(feed_dict, self.update_dict)
for stat_name, update_name in stats_needed.items():
update_stats[stat_name] = update_vals[update_name]
print(
len(
np.mean(
self._execute_model(feed_dict, {"sensi": self.sensitivity})[
"sensi"
][0][0],
axis=0,
)
)
)
return update_stats
return update_stats
def compute_input_sensitivity(self, batch: AgentBuffer, num_sequences: int) -> None:
feed_dict = self._construct_feed_dict(batch, num_sequences)
sens = self._execute_model(feed_dict, {"sensi": self.sensitivity})["sensi"][0]
for obs, grad in sorted(enumerate(sens), reverse=True, key=lambda x: x[1]):
print(obs, grad)
def _construct_feed_dict(
self, mini_batch: AgentBuffer, num_sequences: int

3
ml-agents/mlagents/trainers/ppo/trainer.py


)
num_epoch = self.trainer_parameters["num_epoch"]
batch_update_stats = defaultdict(list)
self.optimizer.compute_input_sensitivity(
self.update_buffer, self.policy.sequence_length
)
for _ in range(num_epoch):
self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
buffer = self.update_buffer

正在加载...
取消
保存