    def register(self, module_dict):
        # Register modules whose variables should be included in checkpoints
        # (no-op in this implementation).
        pass

    def save_checkpoint(self, brain_name: str, step: int) -> str:
        """
        Checkpoints the policy on disk.
        :param brain_name: Brain name of brain to be trained.
        :param step: The current training step.
        :return: Path of the checkpoint that was written.
        """
        checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
        print('save checkpoint_path:', checkpoint_path)
        # Save the TF checkpoint and graph definition
        with self.graph.as_default():
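            # Write the session variables and the serialized graph definition so
            # the checkpoint can later be frozen and exported.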
            if self.saver:
                self.saver.save(self.sess, f"{checkpoint_path}.ckpt")
            tf.train.write_graph(
                self.graph, self.model_path, "raw_graph_def.pb", as_text=False
            )
        # also save the policy so we have optimized model files for each checkpoint
        self.export(checkpoint_path, brain_name)
        return checkpoint_path

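    # Illustrative call pattern for save_checkpoint() (names are hypothetical):
    #     path = model_saver.save_checkpoint("MyBrain", step=10000)
    #     # -> checkpoints under <model_path>/MyBrain-10000 and exports the policy
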
    def export(self, output_filepath: str, brain_name: str) -> None:
        """
        Exports the trained policy model file(s) for the given brain.
        :param output_filepath: path (without suffix) for the model file(s)
        :param brain_name: Brain name of brain to be trained.
        """
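        # export_policy_model is assumed to freeze the current graph and write
        # the model file(s) used for inference.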
        print('export output_filepath:', output_filepath)
        export_policy_model(output_filepath, brain_name, self.graph, self.sess)

    def maybe_load(self):
        # Restore an existing checkpoint when resuming; otherwise start from a
        # freshly initialized graph. (`self.load` is assumed to be the resume
        # flag configured on this saver.)
        if self.load:
            self._load_graph(self.model_path)
        else:
            self.policy._initialize_graph()

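    # _load_graph restores the variables of the most recent checkpoint found
    # under `model_path`; `reset_global_steps` is intended for runs seeded from
    # a pre-trained model, where the step counter should start over.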
    def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
        print('load model_path:', model_path)
        with self.graph.as_default():
            logger.info(f"Loading model from {model_path}.")
            ckpt = tf.train.get_checkpoint_state(model_path)