onnxruntime/tools/python/util/make_dynamic_shape_fixed.py
Scott McKay 2ca9566994
Add range of helpers for making usage of ORT Mobile easier. (#10458)
* Add range of helpers for making usage of ORT Mobile easier.
2022-02-18 07:35:25 +10:00

57 lines
2.5 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import onnx
import os
import pathlib
import sys
from .onnx_model_utils import make_dim_param_fixed, make_input_shape_fixed, fix_output_shapes
def make_dynamic_shape_fixed_helper():
    """Command-line entry point: make a dynamic dimension of an ONNX model fixed.

    Arguments are read from ``sys.argv``. Exactly one of two modes must be
    selected:

    * ``--dim_param NAME --dim_value N``: passes the model graph, NAME and N
      to ``make_dim_param_fixed``.
    * ``--input_name NAME --input_shape d0,d1,...``: passes the model graph,
      NAME and the parsed shape to ``make_input_shape_fixed``.

    After the update ``fix_output_shapes`` is called on the model and the
    result is written to the ``output_model`` path.

    Raises:
        SystemExit: with code -1 (after printing the help text) when both or
            neither mode is selected, or when any supplied dimension value is
            missing or less than 1. argparse itself exits with its usual
            status for unparsable arguments.
        FileNotFoundError: if ``input_model`` does not exist (the path is
            resolved with ``strict=True``).
    """

    def parse_shape(text):
        # Named converter (rather than a lambda) so that a malformed value yields a readable argparse error
        # instead of "invalid <lambda> value".
        try:
            return [int(dim) for dim in text.split(',')]
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"invalid shape {text!r}: expected a comma separated list of integers, "
                "e.g. 1,3,256,256") from None

    parser = argparse.ArgumentParser(f'{os.path.basename(__file__)}:{make_dynamic_shape_fixed_helper.__name__}',
                                     description='''
                                     Assign a fixed value to a dim_param or input shape.
                                     Provide either dim_param and dim_value or input_name and input_shape.''')

    parser.add_argument('--dim_param', type=str, required=False,
                        help="Symbolic parameter name. Provide dim_value if specified.")
    parser.add_argument('--dim_value', type=int, required=False,
                        help="Value to replace dim_param with in the model. Must be > 0.")
    parser.add_argument('--input_name', type=str, required=False,
                        help="Model input name to replace shape of. Provide input_shape if specified.")
    parser.add_argument('--input_shape', type=parse_shape, required=False,
                        help="Shape to use for input_name. Provide comma separated list for the shape. "
                             "All values must be > 0. e.g. --input_shape 1,3,256,256")

    parser.add_argument('input_model', type=pathlib.Path, help='Provide path to ONNX model to update.')
    parser.add_argument('output_model', type=pathlib.Path, help='Provide path to write updated ONNX model to.')

    args = parser.parse_args()

    # Exactly one mode must be chosen, and the value(s) accompanying that mode must be present and positive.
    if (args.dim_param and args.input_name) or \
            (not args.dim_param and not args.input_name) or \
            (args.dim_param and (not args.dim_value or args.dim_value < 1)) or \
            (args.input_name and (not args.input_shape or any(value < 1 for value in args.input_shape))):
        print('Invalid usage.')
        parser.print_help()
        sys.exit(-1)

    # strict=True makes a missing input file fail here, before onnx.load is attempted.
    model = onnx.load(str(args.input_model.resolve(strict=True)))

    if args.dim_param:
        make_dim_param_fixed(model.graph, args.dim_param, args.dim_value)
    else:
        make_input_shape_fixed(model.graph, args.input_name, args.input_shape)

    # update the output shapes to make them fixed if possible.
    fix_output_shapes(model)

    onnx.save(model, str(args.output_model.resolve()))
# Run the command-line helper only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    make_dynamic_shape_fixed_helper()