import logging

import torch

logger = logging.getLogger(__name__)


def restore_state(modules, saved_state_dict, reset_global_steps=False):
    """Restore per-module state from a checkpoint, tolerating partial matches.

    Note: the scaffolding here (imports, signature, and the direct
    ``load_state_dict`` fast path) is an assumed reconstruction; ``modules``
    is taken to map names to ``nn.Module``/optimizer instances and
    ``saved_state_dict`` to hold the matching per-module state dicts.
    """
    for name, mod in modules.items():
        try:
            # Fast path: the saved state dict matches the module exactly.
            mod.load_state_dict(saved_state_dict[name])
        except Exception:
            if name in saved_state_dict:
                logger.warning(f"Failed to load directly for module {name}.")
                if isinstance(mod, torch.optim.Adam):
                    # Optimizer state is tied to parameter shapes, so partial
                    # copies are unsafe; drop it and let training rebuild it.
                    logger.warning(f"Ignoring state for optimizer {name}.")
                else:
                    # Fall back to copying tensor by tensor: entries that are
                    # missing or have changed shape keep their fresh
                    # initialization instead of aborting the whole load.
                    model_state = mod.state_dict()
                    for mod_element in model_state:
                        if mod_element in saved_state_dict[name]:
                            try:
                                logger.warning(f"Copying {mod_element}")
                                model_state[mod_element].copy_(
                                    saved_state_dict[name][mod_element]
                                )
                            except Exception:
                                logger.warning(
                                    f"{mod_element} is not loadable (changed shape). Initializing."
                                )
                        else:
                            logger.warning(
                                f"{mod_element} was not found in the saved module {name}. This will be initialized."
                            )
f"The module {name} was not found in the checkpoint. This will be initialized." |
|
|
|
f"The module {name} was not found in the checkpoint. Initializing." |
|
|
|
) |
|
|
|
|
|
|
|
    if reset_global_steps:
        # Assumption: the flag zeroes the step counter so that fine-tuning
        # from this checkpoint restarts from step 0.
        global_step = 0
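

# Illustrative usage sketch (the names `Net` and "checkpoint.pt" are
# hypothetical); it shows how the tolerant restore above might be driven
# when fine-tuning from a checkpoint whose shapes partially differ:
#
#     model = Net()
#     optimizer = torch.optim.Adam(model.parameters())
#     checkpoint = torch.load("checkpoint.pt", map_location="cpu")
#     restore_state(
#         {"model": model, "optimizer": optimizer},
#         checkpoint,
#         reset_global_steps=True,
#     )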