浏览代码

small improvement

/develop/add-fire/ckpt-2
Ruo-Ping Dong 4 年前
当前提交
09a741c8
共有 4 个文件被更改，包括 29 次插入和 12 次删除
  1. 8
      ml-agents/mlagents/trainers/ppo/trainer.py
  2. 8
      ml-agents/mlagents/trainers/sac/trainer.py
  3. 1
      ml-agents/mlagents/trainers/torch/model_serialization.py
  4. 24
      ml-agents/mlagents/trainers/trainer/rl_trainer.py

8
ml-agents/mlagents/trainers/ppo/trainer.py


self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
if self.saver is None:
self.saver = self.create_saver(policy=policy)
self.saver = self.create_saver(
self.framework,
policy,
self.trainer_settings,
self.artifact_path,
self.load,
)
self.saver.register(self.policy)
self.saver.register(self.optimizer)
self.saver.maybe_load()

8
ml-agents/mlagents/trainers/sac/trainer.py


self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
if self.saver is None:
self.saver = self.create_saver(policy=policy)
self.saver = self.create_saver(
self.framework,
policy,
self.trainer_settings,
self.artifact_path,
self.load,
)
self.saver.register(self.policy)
self.saver.register(self.optimizer)
self.saver.maybe_load()

1
ml-agents/mlagents/trainers/torch/model_serialization.py


}
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories)
@staticmethod
def export_policy_model(self, output_filepath: str) -> None:
"""
Exports a Torch model for a Policy to .onnx format for Unity embedding.

24
ml-agents/mlagents/trainers/trainer/rl_trainer.py


from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.settings import TestingConfiguration
from mlagents.trainers.settings import TestingConfiguration, TrainerSettings
from mlagents.trainers.stats import StatsPropertyType
from mlagents.trainers.saver.saver import BaseSaver
from mlagents.trainers.saver.torch_saver import TorchSaver

pass
@staticmethod
def create_saver(self, policy: Policy) -> BaseSaver:
if self.framework == "torch":
def create_saver(
framework: str,
policy: Policy,
trainer_settings: TrainerSettings,
model_path: str,
load: bool,
) -> BaseSaver:
if framework == "torch":
self.trainer_settings,
model_path=self.artifact_path,
load=self.load,
trainer_settings,
model_path,
load,
self.trainer_settings,
model_path=self.artifact_path,
load=self.load,
trainer_settings,
model_path,
load,
)
return saver

正在加载...
取消
保存