Unity 机器学习代理工具包（ML-Agents）是一个开源项目，它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。
 
 
 
 
 

53 行
2.1 KiB

import os
from typing import Any, List, Set, NamedTuple
import torch
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.settings import SerializationSettings
# Module-level logger for this file, obtained from mlagents_envs' logging utility.
logger = get_logger(__name__)
class ModelSerializer:
    """
    Export a policy's ``actor_critic`` network to ONNX.

    The constructor builds dummy example inputs from sizes read off the
    policy. Each input gets a leading batch dimension of 1. The constructor
    also builds the input/output names and dynamic axes that are passed to
    ``torch.onnx.export``.
    """

    def __init__(self, policy):
        """
        :param policy: policy whose ``actor_critic`` module will be exported.
            Attributes read: ``vec_obs_size``, ``vis_obs_shape``,
            ``vis_obs_size``, ``m_size``, ``sequence_length`` and
            ``actor_critic.act_size``. Shape attributes may be given as
            lists or tuples.
        """
        self.policy = policy
        # Example inputs used to trace the network; each is prefixed with a
        # batch dimension of 1.
        dummy_vec_obs = [torch.zeros([1] + [self.policy.vec_obs_size])]
        # Visual observations are only supplied when the policy has any.
        dummy_vis_obs = (
            [torch.zeros([1] + list(self.policy.vis_obs_shape))]
            if self.policy.vis_obs_size > 0
            else []
        )
        # NOTE(review): the mask shape is [1] + act_size. If act_size holds
        # several discrete branch sizes, this yields a rank > 2 tensor rather
        # than [1, sum(act_size)]. Confirm what actor_critic expects.
        dummy_masks = [
            torch.ones([1] + list(self.policy.actor_critic.act_size))
        ]
        dummy_memories = [torch.zeros([1] + [self.policy.m_size])]
        dummy_sequence_length = [torch.tensor([self.policy.sequence_length])]

        # NOTE(review): torch.onnx.export assigns input_names to the flattened
        # graph inputs in order. When dummy_vis_obs is empty there is one
        # fewer input than names. Verify the names still land on the
        # intended inputs.
        self.input_names = [
            "vector_observation",
            "visual_observation",
            "action_mask",
            "memories",
            "sequence_length",
        ]
        self.output_names = [
            "action",
            "action_probs",
            "version_number",
            "memory_size",
            "is_continuous_control",
            "action_output_shape",
        ]
        # Axis 0 (the batch axis of the dummy inputs) is exported as dynamic,
        # so the model is not fixed to the batch size of 1 used for tracing.
        self.dynamic_axes = {
            "vector_observation": [0],
            "visual_observation": [0],
            "action_mask": [0],
            "memories": [0],
            "action": [0],
            "action_probs": [0],
        }
        self.dummy_input = (
            dummy_vec_obs,
            dummy_vis_obs,
            dummy_masks,
            dummy_memories,
            dummy_sequence_length,
        )

    def export_policy_model(self, output_filepath: str) -> None:
        """
        Export a Torch model for a Policy to .onnx format for Unity embedding.

        :param output_filepath: file path to output the model (without file
            suffix). ".onnx" is appended to form the path that is written.
        """
        # The model is written to ``output_filepath + ".onnx"``, so only the
        # parent directory has to exist. Creating ``output_filepath`` itself
        # as a directory would leave an empty, unused folder next to the
        # model. ``exist_ok`` avoids a check-then-create race.
        parent_dir = os.path.dirname(output_filepath)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        onnx_output_path = f"{output_filepath}.onnx"
        logger.info(f"Converting to {onnx_output_path}")
        torch.onnx.export(
            self.policy.actor_critic,
            self.dummy_input,
            onnx_output_path,
            verbose=True,
            opset_version=SerializationSettings.onnx_opset,
            input_names=self.input_names,
            output_names=self.output_names,
            dynamic_axes=self.dynamic_axes,
        )