mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-27 03:11:28 +00:00
Update ReformatSourcePython.bat to use YAPF to format Python code, and add the onnxruntime\test directory to the set of formatted directories. Add onnxruntime\.style.yapf for configuration. The style is based on the Google style, except with a maximum column width of 120. Format Python scripts using ReformatSourcePython.bat.
100 lines
4 KiB
Python
100 lines
4 KiB
Python
#-------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
#--------------------------------------------------------------------------
|
|
|
|
import argparse
|
|
import onnxruntime as onnxrt
|
|
import numpy as np
|
|
import os
|
|
import sys
|
|
from timeit import default_timer as timer
|
|
|
|
# Maps the ONNX floating-point type strings reported by the session's inputs / initializers
# (their `.type` attribute, e.g. 'tensor(float)') to the numpy dtype used for random test data.
float_dict = {'tensor(float16)': 'float16', 'tensor(float)': 'float32', 'tensor(double)': 'float64'}

# Maps the ONNX integer type strings to the numpy dtype used for random test data.
integer_dict = {
    'tensor(int8)': 'int8',
    'tensor(uint8)': 'uint8',
    'tensor(int16)': 'int16',
    'tensor(uint16)': 'uint16',
    'tensor(int32)': 'int32',
    'tensor(uint32)': 'uint32',
    'tensor(int64)': 'int64',
    'tensor(uint64)': 'uint64',
}
|
|
|
|
# Simple test program: load an ONNX model, feed random data to all of its inputs (and any
# overridable initializers), run the model num_iters times and report the average latency.


def _resolve_shape(onnx_shape):
    """Return a concrete shape tuple, replacing every non-concrete dimension with 1.

    onnxruntime reports an unknown dimension as None and a symbolic dimension as its
    dim_param name (a str such as 'batch'); neither can be passed to numpy as a size.
    A 0 dimension is also replaced with 1, as the original tool did.
    """
    return tuple(dim if dim and not isinstance(dim, str) else 1 for dim in onnx_shape)


def _random_feed(onnx_type, shape):
    """Create a random numpy array of the given shape for an ONNX tensor type string.

    Args:
        onnx_type: type string such as 'tensor(float)' (keys of float_dict / integer_dict,
            or 'tensor(bool)').
        shape: concrete shape tuple, e.g. from _resolve_shape(); () means a scalar.

    Returns:
        A numpy ndarray, or None if onnx_type is not supported.
    """
    if onnx_type in float_dict:
        # random_sample honours size=() (returns a 0-d array), whereas np.random.rand() with no
        # arguments returns a plain Python float that has no .astype().
        values = np.random.random_sample(shape)
        dtype = float_dict[onnx_type]
    elif onnx_type in integer_dict:
        dtype = np.dtype(integer_dict[onnx_type])
        # Integers in [0, 1000), capped to the dtype's range so int8/uint8 do not overflow.
        high = min(1000, int(np.iinfo(dtype).max) + 1)
        values = np.random.randint(0, high, size=shape)
    elif onnx_type == 'tensor(bool)':
        values = np.random.randint(2, size=shape)
        dtype = 'bool'
    else:
        return None
    return np.asarray(values).astype(dtype)


def main():
    """Parse the command line, run the model repeatedly and print latency statistics.

    Returns:
        0 on success. Exits with status -1 if the model has an input or overridable
        initializer of an unsupported type, and via argparse (status 2) on bad arguments.
    """
    parser = argparse.ArgumentParser(description='Simple ONNX Runtime Test Tool.')
    parser.add_argument('model_path', help='model path')
    parser.add_argument('num_iters', nargs='?', type=int, default=1000, help='model run iterations. default=1000')
    parser.add_argument('--debug', action='store_true', help='pause execution to allow attaching a debugger.')
    parser.add_argument('--profile', action='store_true', help='enable chrome timeline trace profiling.')
    args = parser.parse_args()
    iters = args.num_iters
    if iters <= 0:
        # The average latency below divides by iters.
        parser.error('num_iters must be a positive integer')

    if args.debug:
        print("Pausing execution ready for debugger to attach to pid: {}".format(os.getpid()))
        print("Press key to continue.")
        sys.stdin.read(1)

    sess_options = None
    if args.profile:
        sess_options = onnxrt.SessionOptions()
        sess_options.enable_profiling = True
        sess_options.profile_file_prefix = os.path.basename(args.model_path)

    sess = onnxrt.InferenceSession(args.model_path, sess_options)
    meta = sess.get_modelmeta()

    feeds = {}
    for input_meta in sess.get_inputs():
        data = _random_feed(input_meta.type, _resolve_shape(input_meta.shape))
        if data is None:
            print("unsupported input type {} for input {}".format(input_meta.type, input_meta.name))
            sys.exit(-1)
        feeds[input_meta.name] = data

    # Starting with IR4 some initializers provide default values
    # and can be overridden (available in IR4). For IR < 4 models
    # the list would be empty
    for initializer in sess.get_overridable_initializers():
        data = _random_feed(initializer.type, _resolve_shape(initializer.shape))
        if data is None:
            print("unsupported initializer type {} for initializer {}".format(initializer.type, initializer.name))
            sys.exit(-1)
        feeds[initializer.name] = data

    start = timer()
    for _ in range(iters):
        sess.run([], feeds)  # fetch all outputs
    end = timer()

    print("model: {}".format(meta.graph_name))
    print("version: {}".format(meta.version))
    print("iterations: {}".format(iters))
    print("avg latency: {} ms".format(((end - start) * 1000) / iters))

    if args.profile:
        trace_file = sess.end_profiling()
        print("trace file written to: {}".format(trace_file))

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status (what sys.exit() does).
    raise SystemExit(main())
|