policy = cast(TorchPolicy, policy)

for name, mod in modules.items():
    try:
        # Fast path: restore the module's saved state in a single call.
        mod.load_state_dict(saved_state_dict[name])
    # KeyError: the module is absent from the checkpoint.
    # RuntimeError/ValueError: saved keys or shapes no longer match.
    except Exception:
        if name in saved_state_dict:
            logger.warning(f"Failed to load directly for module {name}.")
            if isinstance(mod, torch.optim.Adam):
                # Optimizer state that no longer matches its parameters
                # cannot be partially restored, so it is reinitialized.
                logger.warning(f"Ignoring state for optimizer {name}.")
            else:
                # Fall back to copying tensors one by one so that any
                # entries whose names still match are preserved.
                for mod_element in mod.state_dict():
                    if mod_element in saved_state_dict[name]:
                        logger.warning(f"Copying {mod_element}")
                        mod.state_dict()[mod_element].copy_(
                            saved_state_dict[name][mod_element]
                        )
                    else:
                        logger.warning(
                            f"{mod_element} was not found in the saved module "
                            f"{name}. This will be initialized."
                        )
        else:
            logger.warning(
                f"The module {name} was not found in the checkpoint. "
                "This will be initialized."
            )

if reset_global_steps:
    # Restart the step counter, e.g. when fine-tuning from a checkpoint.
    policy.set_step(0)
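
# --------------------------------------------------------------------------
# A minimal, self-contained sketch (not from the source above) illustrating
# why the element-wise copy_ fallback preserves matching parameters when a
# strict load_state_dict call fails. The checkpoint contents here are
# hypothetical; only standard torch APIs (load_state_dict, state_dict,
# Tensor.copy_) are used.
import torch
import torch.nn as nn

mod = nn.Linear(4, 2)
# Pretend checkpoint: "bias" is missing, so a strict load raises.
partial_checkpoint = {"weight": torch.ones(2, 4)}

try:
    mod.load_state_dict(partial_checkpoint)
except Exception:
    for key in mod.state_dict():
        if key in partial_checkpoint:
            # copy_ writes through to the live parameter storage, so the
            # module keeps the checkpointed weight ...
            mod.state_dict()[key].copy_(partial_checkpoint[key])
        # ... while "bias" keeps its fresh initialization.

assert torch.equal(mod.weight.detach(), torch.ones(2, 4))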