|
|
|
|
|
|
:param checkpoint_path: filepath to write the checkpoint |
|
|
|
:param brain_name: Brain name of brain to be trained |
|
|
|
""" |
|
|
|
print('save checkpoint_path:', checkpoint_path) |
|
|
|
if not os.path.exists(self.model_path): |
|
|
|
os.makedirs(self.model_path) |
|
|
|
state_dict = {name: module.state_dict() for name, module in self.modules.items()} |
|
|
|
|
|
|
self._load_model(self.model_path, reset_global_steps=reset_steps) |
|
|
|
|
|
|
|
def export(self, output_filepath: str, brain_name: str) -> None: |
|
|
|
print('export output_filepath:', output_filepath) |
|
|
|
print('load model_path:', model_path) |
|
|
|
saved_state_dict = torch.load(model_path) |
|
|
|
for name, state_dict in saved_state_dict.items(): |
|
|
|
self.modules[name].load_state_dict(state_dict) |
|
|
|