|
|
|
|
|
|
from mlagents_envs.logging_util import get_logger |
|
|
|
from mlagents.trainers.settings import SerializationSettings |
|
|
|
|
|
|
|
from IPython import embed |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
POSSIBLE_INPUT_NODES = frozenset( |
|
|
|
[ |
|
|
|
"action_masks", |
|
|
|
"prev_action", |
|
|
|
"recurrent_in", |
|
|
|
"sequence_length", |
|
|
|
"vector_observation", |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
POSSIBLE_OUTPUT_NODES = frozenset( |
|
|
|
["action", "action_probs", "recurrent_out", "value_estimate"] |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class ModelSerializer: |
|
|
|
|
|
|
""" |
|
|
|
if not os.path.exists(output_filepath): |
|
|
|
os.makedirs(output_filepath) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
onnx_output_path = f"{output_filepath}.onnx" |
|
|
|
logger.info(f"Converting to {onnx_output_path}") |
|
|
|