|
|
|
|
|
|
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers |
|
|
|
from mlagents.trainers.agent_processor import AgentManagerQueue |
|
|
|
from mlagents.trainers.trajectory import Trajectory |
|
|
|
from mlagents.trainers.settings import TestingConfiguration |
|
|
|
from mlagents.trainers.settings import TestingConfiguration, TrainerSettings |
|
|
|
from mlagents.trainers.stats import StatsPropertyType |
|
|
|
from mlagents.trainers.saver.saver import BaseSaver |
|
|
|
from mlagents.trainers.saver.torch_saver import TorchSaver |
|
|
|
|
|
|
pass |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def create_saver(self, policy: Policy) -> BaseSaver: |
|
|
|
if self.framework == "torch": |
|
|
|
def create_saver( |
|
|
|
framework: str, |
|
|
|
policy: Policy, |
|
|
|
trainer_settings: TrainerSettings, |
|
|
|
model_path: str, |
|
|
|
load: bool, |
|
|
|
) -> BaseSaver: |
|
|
|
if framework == "torch": |
|
|
|
self.trainer_settings, |
|
|
|
model_path=self.artifact_path, |
|
|
|
load=self.load, |
|
|
|
trainer_settings, |
|
|
|
model_path, |
|
|
|
load, |
|
|
|
self.trainer_settings, |
|
|
|
model_path=self.artifact_path, |
|
|
|
load=self.load, |
|
|
|
trainer_settings, |
|
|
|
model_path, |
|
|
|
load, |
|
|
|
) |
|
|
|
return saver |
|
|
|
|
|
|
|