浏览代码

small improvements

/develop/add-fire/ckpt-2
Ruo-Ping Dong 4 年前
当前提交
bdb2ba93
共有 5 个文件被更改,包括 28 次插入、39 次删除
  1. 2
      ml-agents/mlagents/trainers/saver/saver.py
  2. 6
      ml-agents/mlagents/trainers/saver/tf_saver.py
  3. 6
      ml-agents/mlagents/trainers/saver/torch_saver.py
  4. 14
      ml-agents/mlagents/trainers/torch/networks.py
  5. 39
      ml-agents/mlagents/trainers/trainer/rl_trainer.py

2
ml-agents/mlagents/trainers/saver/saver.py


import abc
class Saver(abc.ABC):
class BaseSaver(abc.ABC):
"""This class is the base class for the Saver"""
def __init__(self):

6
ml-agents/mlagents/trainers/saver/tf_saver.py


from mlagents_envs.exception import UnityException
from mlagents_envs.logging_util import get_logger
from mlagents.tf_utils import tf
from mlagents.trainers.saver.saver import Saver
from mlagents.trainers.saver.saver import BaseSaver
from mlagents.trainers.tf.model_serialization import export_policy_model
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings

logger = get_logger(__name__)
class TFSaver(Saver):
class TFSaver(BaseSaver):
"""
Saver class for TensorFlow
"""

self._keep_checkpoints = trainer_settings.keep_checkpoints
self.load = load
self.graph = self.policy.graph
self.sess = self.policy.sess
with self.graph.as_default():

6
ml-agents/mlagents/trainers/saver/torch_saver.py


import torch
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.saver.saver import Saver
from mlagents.trainers.saver.saver import BaseSaver
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.policy.torch_policy import TorchPolicy

logger = get_logger(__name__)
class TorchSaver(Saver):
class TorchSaver(BaseSaver):
"""
Saver class for PyTorch
"""

def register(self, module):
self.modules.update(module.get_modules())
def save_checkpoint(self, checkpoint_path: str, brain_name: str) -> None:
"""
Checkpoints the policy on disk.

14
ml-agents/mlagents/trainers/torch/networks.py


class GlobalSteps(nn.Module):
    """Tracks the global training step count as a non-trainable module parameter.

    Storing the counter in an ``nn.Parameter`` (with ``requires_grad=False``)
    means it is saved/restored automatically with the module's ``state_dict``,
    so the step count survives checkpointing.
    """

    def __init__(self):
        super().__init__()
        # Single-element float tensor; excluded from gradient updates.
        self._global_step = nn.Parameter(torch.Tensor([0]), requires_grad=False)

    @property
    def step(self) -> int:
        """Current global step as a plain Python int."""
        return int(self._global_step.item())

    @step.setter
    def step(self, value) -> None:
        # Assign through .data so the Parameter object itself is preserved
        # (keeps it registered in the module's state_dict).
        # NOTE(review): value is expected to be a tensor here — confirm callers.
        self._global_step.data = value

    def increment(self, value) -> None:
        """Advance the global step counter by ``value``."""
        self._global_step += value
class LearningRate(nn.Module):

39
ml-agents/mlagents/trainers/trainer/rl_trainer.py


Create a Policy object that uses the TensorFlow backend.
"""
pass
return self.create_torch_saver(policy)
saver = TorchSaver(
policy,
self.trainer_settings,
model_path=self.artifact_path,
load=self.load,
)
return self.create_tf_saver(policy)
def create_torch_saver(self, policy: TorchPolicy) -> TorchSaver:
    """
    Build and return a Saver that uses the PyTorch backend.

    :param policy: The TorchPolicy whose state the saver will manage.
    :return: A TorchSaver configured from this trainer's settings.
    """
    # Construct directly from trainer-level configuration; the artifact
    # path is where checkpoints will be written, and `load` controls
    # whether an existing model is restored.
    return TorchSaver(
        policy,
        self.trainer_settings,
        model_path=self.artifact_path,
        load=self.load,
    )
def create_tf_saver(self, policy: TFPolicy) -> TFSaver:
    """
    Create a Saver object that uses the TensorFlow backend.

    :param policy: The TFPolicy whose graph/session the saver will manage.
    :return: A TFSaver configured from this trainer's settings.
    """
    # Fix: the original constructed TFSaver twice back-to-back and
    # discarded the first instance — build it exactly once.
    saver = TFSaver(
        policy,
        self.trainer_settings,
        model_path=self.artifact_path,
        load=self.load,
    )
    return saver
def _policy_mean_reward(self) -> Optional[float]:

正在加载...
取消
保存