mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
63 lines
2.7 KiB
Python
63 lines
2.7 KiB
Python
|
|
#!/usr/bin/env python
|
||
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||
|
|
# Licensed under the MIT License.
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import onnxruntime as ort
|
||
|
|
import os
|
||
|
|
import re
|
||
|
|
|
||
|
|
|
||
|
|
def convert(model: str) -> None:
    """Convert an ONNX model to an optimized ONNX model, then to an ORT format model.

    Two files are written next to ``model`` (same directory, same base name):

    * ``<base>.optimized.onnx`` -- the model after ORT_ENABLE_EXTENDED graph
      optimizations (NCHWc optimizations are skipped), saved in ONNX format.
    * ``<base>.ort`` -- the optimized ONNX model converted as-is (all graph
      optimizations disabled) to ORT format.

    Progress messages and a size comparison of the two outputs are printed.

    Args:
        model: Path to the input model. Must end in ``.onnx``.

    Raises:
        ValueError: If ``model`` does not end in ``.onnx``.
    """
    suffix = '.onnx'
    if not model.endswith(suffix):
        raise ValueError("Model filename must end in .onnx.")

    # Strip the literal '.onnx' suffix. The former regex '.onnx$' used an
    # unescaped '.', which matches any character; slicing states the intent
    # directly and gives the same result for every accepted input.
    base = model[:-len(suffix)]
    onnx_target_path = base + '.optimized.onnx'
    ort_target_path = base + '.ort'

    # First, optimize the ONNX model and save the result in ONNX format.
    optimize_opts = ort.SessionOptions()
    optimize_opts.optimized_model_filepath = onnx_target_path
    optimize_opts.add_session_config_entry('session.save_model_format', 'ONNX')
    optimize_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED  # Skip NCHWc optimizations

    print("Optimizing ONNX model {} and saving in ONNX format to {}".format(model, onnx_target_path))
    # Creating the session will result in the optimized model being saved.
    _ = ort.InferenceSession(model, sess_options=optimize_opts)

    # Second, convert the optimized ONNX model to ORT format. A fresh
    # SessionOptions is used rather than mutating the one from the first step,
    # so 'session.save_model_format' is only ever set once per options object
    # and nothing from the optimize step carries over implicitly.
    convert_opts = ort.SessionOptions()
    convert_opts.optimized_model_filepath = ort_target_path
    # Convert the model as-is so we don't change the kernels in this step.
    convert_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    convert_opts.add_session_config_entry('session.save_model_format', 'ORT')

    print("Converting optimized ONNX model {} to ORT format model {}".format(onnx_target_path, ort_target_path))
    _ = ort.InferenceSession(onnx_target_path, sess_options=convert_opts)

    orig_size = os.path.getsize(onnx_target_path)
    new_size = os.path.getsize(ort_target_path)
    print("Serialized {} to {}. Sizes: orig={} new={} diff={} new:old={:.4f}:1.0".format(
        onnx_target_path, ort_target_path, orig_size, new_size, new_size - orig_size, new_size / orig_size))
|
||
|
|
|
||
|
|
|
||
|
|
def parse_args():
    """Build this script's command-line parser and parse the process arguments.

    Returns:
        argparse.Namespace whose ``model`` attribute is the model path given
        as the single positional argument.
    """
    description = '''Convert an onnx model -> optimized onnx model -> ORT format model.
Expects a .onnx file as input. Optimized onnx model will be saved in the same
directory with an extension of .optimized.onnx.
An ORT format model will be created from the optimized onnx model.
The optimized onnx model should be used as input to a minimal build so that
any post-optimization kernels are included in the build.'''

    arg_parser = argparse.ArgumentParser(prog=os.path.basename(__file__), description=description)
    arg_parser.add_argument('model', help='Provide path to ONNX model to convert. Must have .onnx extension.')
    return arg_parser.parse_args()
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Script entry point: read the model path from the command line and convert it."""
    convert(parse_args().model)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Only run the conversion when executed as a script, not when imported.
    main()
|