浏览代码

write to csv

/sensitivity
Andrew Cohen 5 年前
当前提交
61aa9915
共有 2 个文件被更改,包括 6 次插入3 次删除
  1. 7
      ml-agents/mlagents/trainers/ppo/optimizer.py
  2. 2
      ml-agents/mlagents/trainers/ppo/trainer.py

7
ml-agents/mlagents/trainers/ppo/optimizer.py


from mlagents.trainers.stats import StatsSummary
class PPOOptimizer(TFOptimizer):
def __init__(self, policy: TFPolicy, trainer_params: Dict[str, Any]):
"""

update_stats[stat_name] = update_vals[update_name]
return update_stats
def compute_input_sensitivity(self, batch: AgentBuffer, num_sequences: int) -> Dict[int, float]:
def compute_input_sensitivity(
self, batch: AgentBuffer, num_sequences: int
) -> Dict[int, float]:
out = dict((obs,StatsSummary(grad, 0.0, 0.0)) for obs, grad in enumerate(sens))
out = dict((obs, StatsSummary(grad, 0.0, 0.0)) for obs, grad in enumerate(sens))
for obs, grad in sorted(enumerate(sens), reverse=True, key=lambda x: x[1]):
print("Observation {} has relevance {}".format(obs, grad))
return out

2
ml-agents/mlagents/trainers/ppo/trainer.py


self.update_buffer, self.policy.sequence_length
)
self.csv_writer.write_stats("sensitivity", sensitivities, self.step)
for _ in range(num_epoch):
self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
buffer = self.update_buffer

正在加载...
取消
保存