浏览代码

Don't save model twice, copy instead (#4302)

* Don't save model twice, copy instead

* narrower exception
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
8128defb
共有 2 个文件被更改,包括 23 次插入3 次删除
  1. 18
      ml-agents/mlagents/model_serialization.py
  2. 8
      ml-agents/mlagents/trainers/trainer/rl_trainer.py

18
ml-agents/mlagents/model_serialization.py


from distutils.util import strtobool
import os
import shutil
from typing import Any, List, Set, NamedTuple
from distutils.version import LooseVersion

return strtobool(val)
except Exception:
return False
def copy_model_files(source_nn_path: str, destination_nn_path: str) -> None:
"""
Copy the .nn file at the given source to the destination.
Also copies the corresponding .onnx file if it exists.
"""
shutil.copyfile(source_nn_path, destination_nn_path)
logger.info(f"Copied {source_nn_path} to {destination_nn_path}.")
# Copy the onnx file if it exists
source_onnx_path = os.path.splitext(source_nn_path)[0] + ".onnx"
destination_onnx_path = os.path.splitext(destination_nn_path)[0] + ".onnx"
try:
shutil.copyfile(source_onnx_path, destination_onnx_path)
logger.info(f"Copied {source_onnx_path} to {destination_onnx_path}.")
except OSError:
pass

8
ml-agents/mlagents/trainers/trainer/rl_trainer.py


import abc
import time
import attr
from mlagents.model_serialization import SerializationSettings
from mlagents.model_serialization import SerializationSettings, copy_model_files
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpoint,
NNCheckpointManager,

"Trainer has multiple policies, but default behavior only saves the first."
)
policy = list(self.policies.values())[0]
settings = SerializationSettings(policy.model_path, self.brain_name)
# Copy the checkpointed model files to the final output location
copy_model_files(model_checkpoint.file_path, f"{policy.model_path}.nn")
policy.save(policy.model_path, settings)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
@abc.abstractmethod

正在加载...
取消
保存