|
|
|
|
|
|
from mlagents.trainers.trajectory import Trajectory |
|
|
|
from mlagents.trainers.settings import TestingConfiguration |
|
|
|
from mlagents.trainers.stats import StatsPropertyType |
|
|
|
from mlagents.trainers.saver.saver import Saver |
|
|
|
from mlagents.trainers.saver.saver import BaseSaver |
|
|
|
from mlagents.trainers.saver.torch_saver import TorchSaver |
|
|
|
from mlagents.trainers.saver.tf_saver import TFSaver |
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
pass |
|
|
|
|
|
|
|
def create_saver(self, policy: Policy) -> Saver: |
|
|
|
def create_saver(self, policy: Policy) -> BaseSaver: |
|
|
|
if self.framework == "torch": |
|
|
|
saver = TorchSaver( |
|
|
|
policy, |
|
|
|