浏览代码

load individual elements if state dict load fails

/fix-resume-imi
Andrew Cohen 4 年前
当前提交
d6802cf1
共有 1 个文件被更改,包括 23 次插入1 次删除
  1. 24
      ml-agents/mlagents/trainers/model_saver/torch_model_saver.py

24
ml-agents/mlagents/trainers/model_saver/torch_model_saver.py


policy = cast(TorchPolicy, policy)
for name, mod in modules.items():
mod.load_state_dict(saved_state_dict[name])
try:
mod.load_state_dict(saved_state_dict[name])
except Exception:
if name in saved_state_dict:
logger.warning(f"Failed to load directly for module {name}.")
if isinstance(mod, torch.optim.Adam):
logger.warning(f"Ignoring state for optimizer {name}.")
else:
for mod_element in mod.state_dict():
if mod_element in saved_state_dict[name]:
logger.warning(f"Copying {mod_element}")
mod.state_dict()[mod_element].copy_(
saved_state_dict[name][mod_element]
)
else:
logger.warning(
f"{mod_element} was not found in the saved module {name}. This will be initialized."
)
else:
logger.warning(
f"The module {name} was not found in the checkpoint. This will be initialized."
)
if reset_global_steps:
policy.set_step(0)

正在加载...
取消
保存