mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
95 lines
3.2 KiB
Python
95 lines
3.2 KiB
Python
#-------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
#--------------------------------------------------------------------------
|
|
|
|
import argparse
|
|
import onnxruntime as onnxrt
|
|
import numpy as np
|
|
import os
|
|
import sys
|
|
from timeit import default_timer as timer
|
|
|
|
# ONNX tensor type string -> NumPy dtype name, for the floating-point input
# types this tool can generate random data for.
float_dict = dict([
    ('tensor(float16)', 'float16'),
    ('tensor(float)', 'float32'),
    ('tensor(double)', 'float64'),
])
|
|
|
|
# ONNX tensor type string -> NumPy dtype name, for the integer input types
# this tool can generate random data for. Each NumPy name is exactly the
# text inside "tensor(...)", so the mapping is built from the names alone.
integer_dict = {
    'tensor({})'.format(dtype_name): dtype_name
    for dtype_name in (
        'int32',
        'int8',
        'uint8',
        'int16',
        'uint16',
        'int64',
        'uint64',
    )
}
|
|
|
|
# Simple test program that loads an ONNX model, feeds random data to every input, and runs the model num_iters times.
|
|
|
|
|
|
def _concrete_shape(dims):
    """Return ``dims`` as a list of ints, substituting 1 for dynamic dimensions.

    Any entry that is not a positive integer is replaced with 1. That covers
    ``None`` and ``0`` as well as symbolic dimensions that onnxruntime reports
    as a string name, which are truthy and would otherwise reach NumPy as-is.
    """
    return [dim if isinstance(dim, int) and dim > 0 else 1 for dim in dims]


def _random_input(type_name, shape):
    """Return a random ndarray for one model input, or None if unsupported.

    Args:
        type_name: ONNX tensor type string, e.g. ``'tensor(float)'``.
        shape: sequence of positive ints; may be empty for a scalar input.

    Returns:
        An ndarray of the matching NumPy dtype and the requested shape, or
        ``None`` when ``type_name`` is not in ``float_dict``/``integer_dict``
        and is not ``'tensor(bool)'``.
    """
    dims = tuple(shape)

    # np.asarray() is applied throughout because np.random.rand() with no
    # arguments returns a plain Python float, which has no .astype().
    if type_name in float_dict:
        return np.asarray(np.random.rand(*dims)).astype(float_dict[type_name])

    if type_name in integer_dict:
        dtype = np.dtype(integer_dict[type_name])
        # Keep samples inside the target type's range so that narrow types
        # (int8/uint8) are not fed out-of-range values by the float->int cast.
        high = min(1000, np.iinfo(dtype).max)
        return np.asarray(np.random.uniform(high=high, size=dims)).astype(dtype)

    if type_name == 'tensor(bool)':
        return np.asarray(np.random.randint(2, size=dims)).astype('bool')

    return None


def main():
    """Load an ONNX model, feed random inputs, and report the average latency.

    Command-line arguments:
        model_path: path of the model to load.
        num_iters:  optional number of timed runs (default 1000, must be >= 1).
        --debug:    print the pid and wait for a key press before continuing.
        --profile:  enable onnxruntime profiling and print the trace file path.

    Returns:
        0 on success. The process exits with status -1 if an input has an
        unsupported type, and argparse exits with status 2 on bad arguments.
    """
    parser = argparse.ArgumentParser(description='Simple ONNX Runtime Test Tool.')
    parser.add_argument('model_path', help='model path')
    parser.add_argument('num_iters', nargs='?', type=int,
                        default=1000, help='model run iterations. default=1000')
    parser.add_argument('--debug', action='store_true',
                        help='pause execution to allow attaching a debugger.')
    parser.add_argument('--profile', action='store_true',
                        help='enable chrome timeline trace profiling.')
    args = parser.parse_args()
    iters = args.num_iters

    # Reject non-positive counts up front: zero would divide by zero when the
    # average is computed, and a negative count would time an empty loop.
    if iters < 1:
        parser.error("num_iters must be a positive integer, got {}".format(iters))

    if args.debug:
        print("Pausing execution ready for debugger to attach to pid: {}".format(
            os.getpid()))
        print("Press key to continue.")
        sys.stdin.read(1)

    sess_options = None
    if args.profile:
        sess_options = onnxrt.SessionOptions()
        sess_options.enable_profiling = True
        sess_options.profile_file_prefix = os.path.basename(args.model_path)

    sess = onnxrt.InferenceSession(args.model_path, sess_options)
    meta = sess.get_modelmeta()

    feeds = {}
    for input_meta in sess.get_inputs():
        feed = _random_input(input_meta.type, _concrete_shape(input_meta.shape))
        if feed is None:
            print("unsupported input type {} for input {}".format(
                input_meta.type, input_meta.name))
            sys.exit(-1)
        feeds[input_meta.name] = feed

    start = timer()
    for _ in range(iters):
        sess.run([], feeds)  # fetch all outputs
    end = timer()

    print("model: {}".format(meta.graph_name))
    print("version: {}".format(meta.version))
    print("iterations: {}".format(iters))
    print("avg latency: {} ms".format(((end - start) * 1000) / iters))

    if args.profile:
        trace_file = sess.end_profiling()
        print("trace file written to: {}".format(trace_file))

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the tool and hand main()'s return value to the interpreter as the
    # process exit status.
    raise SystemExit(main())
|