浏览代码

add special case for non nn.Module load and comment

/fix-resume-imi
Andrew Cohen 3 年前
当前提交
2a5b58dd
共有 1 个文件被更改,包括 23 次插入11 次删除
  1. 34
      ml-agents/mlagents/trainers/model_saver/torch_model_saver.py

34
ml-agents/mlagents/trainers/model_saver/torch_model_saver.py


for name, mod in modules.items():
try:
missing_keys, unexpected_keys = mod.load_state_dict(
saved_state_dict[name], strict=False
)
if missing_keys:
logger.warning(
f"Did not find these keys {missing_keys} in checkpoint. Initializing."
if isinstance(mod, torch.nn.Module):
missing_keys, unexpected_keys = mod.load_state_dict(
saved_state_dict[name], strict=False
if unexpected_keys:
logger.warning(
f"Did not expect these keys {unexpected_keys} in checkpoint. Ignoring."
)
if missing_keys:
logger.warning(
f"Did not find these keys {missing_keys} in checkpoint. Initializing."
)
if unexpected_keys:
logger.warning(
f"Did not expect these keys {unexpected_keys} in checkpoint. Ignoring."
)
else:
# optimizers are treated separately
mod.load_state_dict(saved_state_dict[name])
except (KeyError, TypeError, RuntimeError) as err:
# KeyError is raised if the module was not present in the last run but is being
# accessed in the saved_state_dict.
# ValueError is raised by the optimizer's load_state_dict if the parameters have
# have changed. Note, the optimizer uses a completely different load_state_dict
# function because it is not an nn.Module.
# RuntimeError is raised by PyTorch if there is a size mismatch between modules
# of the same name. This will still partially assign values to those layers that
# have no changed shape.
except (KeyError, ValueError, RuntimeError) as err:
logger.warning(f"Failed to load for module {name}. Initializing")
logger.debug(f"Module loading error : {err}")

正在加载...
取消
保存