
[Fix] The summary writer is now implemented in the abstract trainer class. (#806)

The summary writer now logs "{}: Step: {}. No episode was completed since last summary." when no episodes were completed since the last summary. (A minimal sketch of the resulting class layout follows the file list below.)
Branch: /develop-generalizationTraining-TrainerController
GitHub · 7 years ago
Commit 702d98c6
3 files changed, 23 insertions and 45 deletions
  1. python/unitytrainers/bc/trainer.py (21 changes)
  2. python/unitytrainers/ppo/trainer.py (23 changes)
  3. python/unitytrainers/trainer.py (24 changes)
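
A minimal sketch of the class layout this commit moves to; the subclass names are illustrative stand-ins for the BC and PPO trainers touched below, and the bodies are abbreviated:

# After this commit, write_summary lives once on the base Trainer;
# the concrete trainers inherit it instead of each carrying a copy.
class Trainer(object):
    def write_summary(self, lesson_number):
        # Shared TensorBoard reporting (full body in trainer.py below),
        # including the new "No episode was completed" log message.
        ...

class BehavioralCloningTrainer(Trainer):
    pass  # no longer defines its own write_summary

class PPOTrainer(Trainer):
    pass  # no longer defines its own write_summary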

python/unitytrainers/bc/trainer.py (21 changes)


         else:
             self.stats['losses'].append(0)

-    def write_summary(self, lesson_number):
-        """
-        Saves training statistics to Tensorboard.
-        :param lesson_number: The lesson the trainer is at.
-        """
-        if (self.get_step % self.trainer_parameters['summary_freq'] == 0 and self.get_step != 0 and
-                self.is_training and self.get_step <= self.get_max_steps):
-            steps = self.get_step
-            if len(self.stats['cumulative_reward']) > 0:
-                mean_reward = np.mean(self.stats['cumulative_reward'])
-                logger.info("{0} : Step: {1}. Mean Reward: {2}. Std of Reward: {3}."
-                            .format(self.brain_name, steps, mean_reward, np.std(self.stats['cumulative_reward'])))
-            summary = tf.Summary()
-            for key in self.stats:
-                if len(self.stats[key]) > 0:
-                    stat_mean = float(np.mean(self.stats[key]))
-                    summary.value.add(tag='Info/{}'.format(key), simple_value=stat_mean)
-                    self.stats[key] = []
-            summary.value.add(tag='Info/Lesson', simple_value=lesson_number)
-            self.summary_writer.add_summary(summary, steps)
-            self.summary_writer.flush()
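
For reference, a hedged, standalone sketch of the TensorFlow 1.x summary idiom the removed method relies on; the log directory, tag, and values here are illustrative, not from the commit:

import tensorflow as tf  # TensorFlow 1.x API

writer = tf.summary.FileWriter('./logs')       # illustrative log directory
summary = tf.Summary()                         # protobuf holding tag/value pairs
summary.value.add(tag='Info/cumulative_reward', simple_value=1.5)
writer.add_summary(summary, global_step=100)   # the step becomes the x-axis in TensorBoard
writer.flush()                                 # force the event file to disk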

python/unitytrainers/ppo/trainer.py (23 changes)


         self.stats['policy_loss'].append(np.mean(total_p))
         self.training_buffer.reset_update_buffer()

-    def write_summary(self, lesson_number):
-        """
-        Saves training statistics to Tensorboard.
-        :param lesson_number: The lesson the trainer is at.
-        """
-        if (self.get_step % self.trainer_parameters['summary_freq'] == 0 and self.get_step != 0 and
-                self.is_training and self.get_step <= self.get_max_steps):
-            steps = self.get_step
-            if len(self.stats['cumulative_reward']) > 0:
-                mean_reward = np.mean(self.stats['cumulative_reward'])
-                logger.info(" {}: Step: {}. Mean Reward: {:0.3f}. Std of Reward: {:0.3f}."
-                            .format(self.brain_name, steps, mean_reward, np.std(self.stats['cumulative_reward'])))
-            summary = tf.Summary()
-            for key in self.stats:
-                if len(self.stats[key]) > 0:
-                    stat_mean = float(np.mean(self.stats[key]))
-                    summary.value.add(tag='Info/{}'.format(key), simple_value=stat_mean)
-                    self.stats[key] = []
-            summary.value.add(tag='Info/Lesson', simple_value=lesson_number)
-            self.summary_writer.add_summary(summary, steps)
-            self.summary_writer.flush()

 def discount_rewards(r, gamma=0.99, value_next=0.0):
     """
     Computes discounted sum of future rewards for use in updating value estimate.

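The diff cuts discount_rewards off at its docstring. A standard backward-recursion implementation consistent with the signature shown (a sketch, not necessarily the exact committed body):

import numpy as np

def discount_rewards(r, gamma=0.99, value_next=0.0):
    """Computes discounted sum of future rewards for use in updating value estimate."""
    r = np.asarray(r, dtype=np.float64)
    discounted_r = np.zeros_like(r)
    running_add = value_next                 # bootstrap from the next state's value estimate
    for t in reversed(range(r.size)):
        running_add = gamma * running_add + r[t]
        discounted_r[t] = running_add
    return discounted_r

For example, discount_rewards([0.0, 0.0, 1.0], gamma=0.5) returns [0.25, 0.5, 1.0].
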
python/unitytrainers/trainer.py (24 changes)


 import logging

+import tensorflow as tf
+import numpy as np

 from unityagents import UnityException, AllBrainInfo

         self.trainer_parameters = trainer_parameters
         self.is_training = training
         self.sess = sess
         self.stats = {}
         self.summary_writer = None

     def __str__(self):
         return '''Empty Trainer'''

         Saves training statistics to Tensorboard.
         :param lesson_number: The lesson the trainer is at.
         """
-        raise UnityTrainerException("The write_summary method was not implemented.")
+        if (self.get_step % self.trainer_parameters['summary_freq'] == 0 and self.get_step != 0 and
+                self.is_training and self.get_step <= self.get_max_steps):
+            steps = self.get_step
+            if len(self.stats['cumulative_reward']) > 0:
+                mean_reward = np.mean(self.stats['cumulative_reward'])
+                logger.info(" {}: Step: {}. Mean Reward: {:0.3f}. Std of Reward: {:0.3f}."
+                            .format(self.brain_name, steps,
+                                    mean_reward, np.std(self.stats['cumulative_reward'])))
+            else:
+                logger.info(" {}: Step: {}. No episode was completed since last summary."
+                            .format(self.brain_name, steps))
+            summary = tf.Summary()
+            for key in self.stats:
+                if len(self.stats[key]) > 0:
+                    stat_mean = float(np.mean(self.stats[key]))
+                    summary.value.add(tag='Info/{}'.format(key), simple_value=stat_mean)
+                    self.stats[key] = []
+            summary.value.add(tag='Info/Lesson', simple_value=lesson_number)
+            self.summary_writer.add_summary(summary, steps)
+            self.summary_writer.flush()
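
With write_summary on the base class, a caller such as the trainer controller can treat every trainer uniformly; a hedged usage sketch, where trainers and lesson_number are hypothetical names:

# Both the BC and PPO trainers now go through the same inherited method.
for brain_name, trainer in trainers.items():
    trainer.write_summary(lesson_number)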
     def write_tensorboard_text(self, key, input_dict):
         """

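The diff also truncates write_tensorboard_text at its docstring. One plausible body, assuming TF 1.x's tf.summary.text and the sess/summary_writer attributes set in __init__ above (a sketch, not the committed code):

    def write_tensorboard_text(self, key, input_dict):
        """
        Saves text to Tensorboard.  (Sketch: key names the summary, input_dict
        holds the name/value pairs to record.)
        """
        try:
            # Render the dict as rows of [name, value] strings for tf.summary.text.
            s_op = tf.summary.text(key, tf.convert_to_tensor(
                [[str(x), str(input_dict[x])] for x in input_dict]))
            s = self.sess.run(s_op)
            self.summary_writer.add_summary(s, self.get_step)
        except Exception:
            logger.info("Could not write text summary for Tensorboard.")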