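# Illustrative setup, not part of the original code: a minimal sketch of the
# names the loop below assumes. `modules` maps names to nn.Modules and
# optimizers, and `saved_state_dict` holds the corresponding per-module state
# dicts restored from a checkpoint (e.g. via torch.load). The tiny Linear
# model, the SGD optimizer, and the in-memory "checkpoint" are assumptions
# made purely for illustration.
import logging

import torch

logger = logging.getLogger(__name__)

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
modules = {"model": model, "optimizer": optimizer}

# Stand-in for `torch.load(checkpoint_path)`: one state dict per module name.
saved_state_dict = {name: mod.state_dict() for name, mod in modules.items()}
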
for name, mod in modules.items():
    try:
        if isinstance(mod, torch.nn.Module):
            # strict=False restores whatever matches and reports the rest.
            missing_keys, unexpected_keys = mod.load_state_dict(
                saved_state_dict[name], strict=False
            )
            if missing_keys:
                logger.warning(
                    f"Did not find these keys {missing_keys} in checkpoint. Initializing."
                )
            if unexpected_keys:
                logger.warning(
                    f"Did not expect these keys {unexpected_keys} in checkpoint. Ignoring."
                )
        else:
            # Optimizers are treated separately: their load_state_dict is not
            # the nn.Module one, takes no `strict` argument, and returns None.
            mod.load_state_dict(saved_state_dict[name])
    except (KeyError, ValueError, RuntimeError) as err:
        # KeyError is raised if the module was not present in the last run but
        # is being accessed in saved_state_dict.
        # ValueError is raised by the optimizer's load_state_dict if the
        # parameters have changed. Note, the optimizer uses a completely
        # different load_state_dict function because it is not an nn.Module.
        # RuntimeError is raised by PyTorch if there is a size mismatch between
        # modules of the same name. Loading with strict=False will still
        # partially assign values to the layers whose shapes have not changed.
        logger.warning(f"Failed to load for module {name}. Initializing.")
        logger.debug(f"Module loading error: {err}")