浏览代码

'clean up' for Scott

/sensitivity
Andrew Cohen 4 年前
当前提交
e55ecd61
共有 1 个文件被更改,包括 8 次插入35 次删除
  1. 43
      ml-agents/mlagents/trainers/stats.py

43
ml-agents/mlagents/trainers/stats.py


self._maybe_create_summary_writer(category)
# adapted from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
counts, bin_edges = np.histogram(value, bins=len(value))
#print(float(np.min(value)))
#print(float(np.max(value)))
#print(int(np.prod(value.shape)))
#print(float(np.sum(value)))
#print(float(np.sum(value**2)))
#print(bin_edges[1:])
#hist.min = float(np.min(value))
#hist.max = float(np.max(value))
#hist.num = int(np.prod(value.shape))
#hist.sum = float(np.sum(value))
#hist.sum_squares = float(np.sum(value**2))
#value = np.log(value)
for obs, grad in sorted(enumerate(value), reverse=True, key=lambda x: x[1]):
print(f"Observation {obs} has relevance {grad}")
# for obs, grad in sorted(enumerate(value), reverse=True, key=lambda x: x[1]):
# print(f"Observation {obs} has relevance {grad}")
hist.max = float(len(value))#float(np.max(value))
hist.num = len(value)#int(np.prod(value.shape))
hist.max = float(len(value))
hist.num = len(value)
hist.sum_squares = float(np.sum(value**2))
hist.sum_squares = float(np.sum(value ** 2))
for edge in range(len(value)):#counts:
#print(edge)
hist.bucket_limit.append(edge+.5)
for edge in range(len(value)):
hist.bucket_limit.append(edge + 0.5)
#print(c)
# Add bin edges and counts
# for edge,i in zip(range(1,len(value)), bin_edges):
# hist.bucket_limit.append(i)
# for c,i in zip(value, counts):
# hist.bucket.append(i)
#summary = tf.Summary()
#summary.value.add(tag="Saliency", histo=hist)
#self.summary_writers[category].add_summary(summary)#, self.trajectories)
#self.summary_writers[category].flush()
#with tf.Session(config=generate_session_config()) as sess:
# hist_op = tf.summary.histogram(category, value)
# hist = sess.run(hist_op)
def _dict_to_tensorboard(
self, name: str, input_dict: Dict[str, Any]
) -> Optional[bytes]:

正在加载...
取消
保存