浏览代码

use load_state_dict strict = False

/fix-resume-imi
Andrew Cohen 3 年前
当前提交
1283e055
共有 1 个文件被更改,包括 7 次插入16 次删除
  1. 23
      ml-agents/mlagents/trainers/model_saver/torch_model_saver.py

23
ml-agents/mlagents/trainers/model_saver/torch_model_saver.py


for name, mod in modules.items():
try:
mod.load_state_dict(saved_state_dict[name])
except Exception:
if name in saved_state_dict:
logger.warning(f"Failed to load directly for module {name}.")
for mod_element in mod.state_dict():
try:
logger.debug(f"Copying {mod_element}")
mod.state_dict()[mod_element].copy_(
saved_state_dict[name][mod_element]
)
except Exception:
logger.debug(
f"{mod_element} was not found or is not loadable (changed shape). Initializing."
)
else:
missing_keys, _ = mod.load_state_dict(
saved_state_dict[name], strict=False
)
if missing_keys:
f"The module {name} was not found in the checkpoint. Initializing."
f"Did not find these keys {missing_keys} in checkpoint. Initializing"
except (KeyError, TypeError):
logger.warning(f"Failed to load for module {name}. Initializing")
if reset_global_steps:
policy.set_step(0)

正在加载...
取消
保存