|
|
|
|
|
|
|
|
|
|
for name, mod in modules.items(): |
|
|
|
try: |
|
|
|
missing_keys, _ = mod.load_state_dict( |
|
|
|
missing_keys, unexpected_keys = mod.load_state_dict( |
|
|
|
saved_state_dict[name], strict=False |
|
|
|
) |
|
|
|
if missing_keys: |
|
|
|
|
|
|
if unexpected_keys: |
|
|
|
logger.warning( |
|
|
|
f"Did not expect these keys {unexpected_keys} in checkpoint. Initializing" |
|
|
|
) |
|
|
|
|
|
|
|
except (KeyError, TypeError): |
|
|
|
logger.warning(f"Failed to load for module {name}. Initializing") |
|
|
|
|
|
|
|