浏览代码

fix export input names

/develop/add-fire/ckpt-2
Ruo-Ping Dong 4 年前
当前提交
f40996e2
共有 3 个文件被更改,包括 32 次插入31 次删除
  1. 17
      ml-agents/mlagents/trainers/saver/torch_saver.py
  2. 44
      ml-agents/mlagents/trainers/torch/model_serialization.py
  3. 2
      ml-agents/mlagents/trainers/trainer/rl_trainer.py

17
ml-agents/mlagents/trainers/saver/torch_saver.py


import shutil
import torch
from typing import Dict, Union, Optional, cast
from mlagents_envs.exception import UnityPolicyException
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.saver.saver import BaseSaver
from mlagents.trainers.settings import TrainerSettings, SerializationSettings

self.modules: Dict[str, torch.nn.Modules] = {}
def register(self, module: Union[TorchPolicy, TorchOptimizer]) -> None:
    """Register a policy's or optimizer's torch modules with this saver.

    :param module: TorchPolicy or TorchOptimizer whose torch.nn modules
        should be tracked for checkpointing/export.
    :raises UnityPolicyException: if module is of an unsupported type.
    """
    # Validate the type BEFORE touching self.modules — diff residue had an
    # unconditional update here that defeated the check.
    if isinstance(module, (TorchPolicy, TorchOptimizer)):
        self.modules.update(module.get_modules())  # type: ignore
    else:
        raise UnityPolicyException(
            "Registering Object of unsupported type {} to Saver ".format(
                type(module)
            )
        )
    # The first TorchPolicy registered becomes the default export target.
    if self.policy is None and isinstance(module, TorchPolicy):
        self.policy = module
        self.exporter = ModelSerializer(self.policy)

# Initialize/Load registered self.policy by default.
# If given input argument policy, use the input policy instead.
# This argument is mainly for initialization of the ghost trainer's fixed policy.
reset_steps = not self.load
if self.initialize_path is not None:
self._load_model(

Also copies the corresponding .onnx file if it exists.
"""
final_model_name = os.path.splitext(source_nn_path)[0]
if SerializationSettings.convert_to_barracuda:
source_path = f"{final_model_name}.nn"
destination_path = f"{self.model_path}.nn"
shutil.copyfile(source_path, destination_path)
logger.info(f"Copied {source_path} to {destination_path}.")
if SerializationSettings.convert_to_onnx:
try:

44
ml-agents/mlagents/trainers/torch/model_serialization.py


class ModelSerializer:
    """Prepares a policy's actor-critic for ONNX export.

    Builds dummy inputs with a leading batch dimension of 1 and records the
    input/output names plus dynamic (batch) axes that torch.onnx.export needs.
    """

    def __init__(self, policy):
        # NOTE(review): policy is assumed to be a TorchPolicy — confirm at call site.
        self.policy = policy
        # Every dummy tensor carries batch size 1; the batch axis is declared
        # dynamic below so exported models accept any batch size.
        batch_dim = [1]
        dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])]
        dummy_vis_obs = [torch.zeros(batch_dim + list(self.policy.vis_obs_shape))]
        # The action mask spans all branches, hence the sum of per-branch sizes.
        dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
        # Memories carry an extra sequence-length dimension of 1.
        dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size])

        # Only name (and mark dynamic) the inputs this policy actually uses,
        # so the exported input names match the model's real signature —
        # the point of the "fix export input names" change. The diff residue
        # also carried the old static four-name list; it is dropped here.
        self.input_names = []
        self.dynamic_axes = {"action": {0: "batch"}, "action_probs": {0: "batch"}}
        if self.policy.use_vec_obs:
            self.input_names.append("vector_observation")
            self.dynamic_axes.update({"vector_observation": {0: "batch"}})
        if self.policy.use_vis_obs:
            self.input_names.append("visual_observation")
            self.dynamic_axes.update({"visual_observation": {0: "batch"}})
        if not self.policy.use_continuous_act:
            # Discrete control: action masking is an additional model input.
            self.input_names.append("action_mask")
            self.dynamic_axes.update({"action_mask": {0: "batch"}})
        if self.policy.use_recurrent:
            self.input_names.append("memories")
            self.dynamic_axes.update({"memories": {0: "batch"}})

        self.output_names = [
            "action",
            "action_probs",
            # NOTE(review): the diff view may omit entries here (e.g.
            # version_number / memory_size) — confirm against the full file.
            "action_output_shape",
        ]
        self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories)
def export_policy_model(self, output_filepath: str) -> None:

self.policy.actor_critic,
self.dummy_input,
onnx_output_path,
verbose=True,
verbose=False,
opset_version=SerializationSettings.onnx_opset,
input_names=self.input_names,
output_names=self.output_names,

2
ml-agents/mlagents/trainers/trainer/rl_trainer.py


def create_saver(
framework: str, trainer_settings: TrainerSettings, model_path: str, load: bool
) -> BaseSaver:
if framework == "torch":
if framework == FrameworkType.PYTORCH:
saver = TorchSaver( # type: ignore
trainer_settings, model_path, load
)

正在加载...
取消
保存