|
|
|
|
|
|
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.trainer import Trainer

# pympler (and guppy, imported inline below) are used for ad-hoc memory
# diagnostics: they print summaries of live objects at summary/checkpoint time.
from pympler import muppy, summary

logger = get_logger(__name__)
|
|
class RLTrainer(Trainer):  # pylint: disable=abstract-method
    """
    This class is the base class for trainers that use Reward Signals.
    """
|
    def _maybe_write_summary(self, step_after_process: int) -> None:
        """Write a summary and dump memory diagnostics once the next summary step is reached."""
        self._next_summary_step = self._get_next_interval_step(self.summary_freq)
        if step_after_process >= self._next_summary_step and self.get_step != 0:
            self._write_summary(self._next_summary_step)
            # Memory diagnostics: print a guppy heap snapshot and a pympler
            # summary of all live objects each time a summary is written.
            from guppy import hpy

            h = hpy()
            print(h.heap())
            all_objects = muppy.get_objects()
            sum1 = summary.summarize(all_objects)
            summary.print_(sum1)
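    # Illustrative sketch (not part of the original trainer): pympler's
    # SummaryTracker prints only the objects that appeared or grew since the
    # previous call, which is usually easier to read than a full heap dump
    # when hunting for a leak. The helper name and the _memory_tracker
    # attribute are hypothetical additions.
    def _print_memory_growth(self) -> None:
        from pympler import tracker

        # Lazily create one tracker per trainer so successive calls are diffed
        # against the previous summary write rather than against process start.
        if not hasattr(self, "_memory_tracker"):
            self._memory_tracker = tracker.SummaryTracker()
        self._memory_tracker.print_diff()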
|
|
|
|
|
|
|
    def _maybe_save_model(self, step_after_process: int) -> None:
        """Save a checkpoint and dump memory diagnostics once the next save step is reached."""
        self._next_save_step = self._get_next_interval_step(
            self.trainer_settings.checkpoint_interval
        )
        if step_after_process >= self._next_save_step and self.get_step != 0:
            # Memory diagnostics: print a guppy heap snapshot before checkpointing.
            from guppy import hpy

            h = hpy()
            print(h.heap())
            self._checkpoint()
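    # Illustrative sketch (not part of the original trainer): the standard
    # library's tracemalloc attributes allocations to source lines, which the
    # guppy/pympler summaries above do not. The helper name is hypothetical;
    # tracemalloc.start() must run before the allocations of interest occur.
    def _print_top_allocations(self, limit: int = 10) -> None:
        import tracemalloc

        if not tracemalloc.is_tracing():
            # First call: start tracing so later checkpoints have data to report.
            tracemalloc.start()
            return
        snapshot = tracemalloc.take_snapshot()
        # Group allocation statistics by the source line that made them.
        for stat in snapshot.statistics("lineno")[:limit]:
            print(stat)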
|
|
|
|
|
|
|
    def advance(self) -> None:
|
|
|