|
|
|
|
|
|
with self.policy.graph.as_default(): |
|
|
|
self.tf_saver = tf.train.Saver(max_to_keep=self._keep_checkpoints) |
|
|
|
|
|
|
|
def save_checkpoint(self, behavior_name: str, step: int) -> str:
    """
    Write a checkpoint of the current policy to disk.

    :param behavior_name: Behavior name of the policy being checkpointed;
        used as the file-name prefix.
    :param step: Current training step; appended to the checkpoint name.
    :return: Base path (without extension) of the checkpoint files written.
    """
    # NOTE(review): SOURCE contained two merge-duplicated definitions of this
    # method (parameter spelled behavior_name vs. brain_name); collapsed into
    # one, keeping the behavior_name spelling used by the first definition.
    checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
    # Save the TF checkpoint and graph definition
    if self.graph:
        with self.graph.as_default():
            # NOTE(review): the saver call and the write_graph call name were
            # missing from the visible fragment; restored from the standard
            # tf.train.Saver / tf.train.write_graph pattern — confirm upstream.
            if self.tf_saver:
                self.tf_saver.save(self.sess, f"{checkpoint_path}.ckpt")
            tf.train.write_graph(
                self.graph, self.model_path, "raw_graph_def.pb", as_text=False
            )
    # also save the policy so we have optimized model files for each checkpoint
    self.export(checkpoint_path, behavior_name)
    # The method is annotated -> str but the fragment had no return; callers
    # need the path to the files just written.
    return checkpoint_path
|
|
|
def export(self, output_filepath: str, behavior_name: str) -> None:
    """
    Export the policy as an optimized, inference-ready model file.

    :param output_filepath: Destination path (without extension) for the
        exported model.
    :param behavior_name: Behavior name of the policy being exported.
    """
    # NOTE(review): SOURCE held two merge-duplicated signatures (behavior_name
    # vs. brain_name); collapsed to the behavior_name spelling, consistent with
    # save_checkpoint. The callee name was missing from the fragment —
    # export_policy_model is the serializer implied by this argument list;
    # confirm against the file's imports.
    export_policy_model(
        self.model_path, output_filepath, behavior_name, self.graph, self.sess
    )
|
|
|
|
|
|
|
def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
    """
    Initialize the policy from scratch, or load previously saved weights into
    it, depending on this saver's configuration.

    :param policy: Policy to initialize or load. Falls back to the saver's own
        registered policy when None.
    """
    if policy is None:
        policy = self.policy
    # Reset the global step counter unless we are resuming a prior run.
    # NOTE(review): reset_steps was referenced but never defined in the visible
    # fragment; presumably derived from the saver's load/resume flag — confirm
    # against the original file.
    reset_steps = not self.load
    # NOTE(review): the condition guarding the load branch was missing from the
    # fragment (only a dangling `else:` was visible); assuming it keys off the
    # same load flag — confirm upstream.
    if self.load:
        self._load_graph(policy, self.model_path, reset_global_steps=reset_steps)
    else:
        policy.initialize()
    # Sync weights from rank 0 so all distributed workers start identical.
    TFPolicy.broadcast_global_variables(0)
|
|
|
|
|
|
|
def _load_graph( |
|
|
|