|
|
|
|
|
|
|
|
|
|
for name, mod in modules.items(): |
|
|
|
try: |
|
|
|
mod.load_state_dict(saved_state_dict[name]) |
|
|
|
except Exception: |
|
|
|
if name in saved_state_dict: |
|
|
|
logger.warning(f"Failed to load directly for module {name}.") |
|
|
|
for mod_element in mod.state_dict(): |
|
|
|
try: |
|
|
|
logger.debug(f"Copying {mod_element}") |
|
|
|
mod.state_dict()[mod_element].copy_( |
|
|
|
saved_state_dict[name][mod_element] |
|
|
|
) |
|
|
|
except Exception: |
|
|
|
logger.debug( |
|
|
|
f"{mod_element} was not found or is not loadable (changed shape). Initializing." |
|
|
|
) |
|
|
|
else: |
|
|
|
missing_keys, _ = mod.load_state_dict( |
|
|
|
saved_state_dict[name], strict=False |
|
|
|
) |
|
|
|
if missing_keys: |
|
|
|
f"The module {name} was not found in the checkpoint. Initializing." |
|
|
|
f"Did not find these keys {missing_keys} in checkpoint. Initializing" |
|
|
|
except (KeyError, TypeError): |
|
|
|
logger.warning(f"Failed to load for module {name}. Initializing") |
|
|
|
|
|
|
|
if reset_global_steps: |
|
|
|
policy.set_step(0) |
|
|
|