mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
No fallback when enforcing explicit EP registration (#9863)
* No fallback when enforcing explicit EP (execution provider) registration.
* Add explicit EP registrations for Python.
This commit is contained in:
parent
a3ebc5e082
commit
1e9e57df3e
3 changed files with 4 additions and 3 deletions
|
|
@ -121,7 +121,7 @@ def perf_test(rnn_type, num_threads, input_dim, hidden_dim, bidirectional, layer
|
|||
convert_to_scan_model(model_name, scan_model_name)
|
||||
# note that symbolic shape inference is needed because model has symbolic batch dim, thus init_state is ConstantOfShape
|
||||
onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(scan_model_name)), scan_model_name)
|
||||
sess = onnxruntime.InferenceSession(scan_model_name)
|
||||
sess = onnxruntime.InferenceSession(scan_model_name, providers=onnxruntime.get_available_providers())
|
||||
count, duration, per_iter_cost = perf_run(sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds)
|
||||
avg_scan = top_n_avg(per_iter_cost, top_n)
|
||||
print('perf_scan (with {} threads) {}: run for {} iterations, top {} avg {:.3f} ms'.format(num_threads, scan_model_name, count, top_n, avg_scan))
|
||||
|
|
@ -131,7 +131,7 @@ def perf_test(rnn_type, num_threads, input_dim, hidden_dim, bidirectional, layer
|
|||
int8_model_name = os.path.splitext(model_name)[0] + '_int8.onnx'
|
||||
convert_matmul_model(scan_model_name, int8_model_name)
|
||||
onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(int8_model_name)), int8_model_name)
|
||||
sess = onnxruntime.InferenceSession(int8_model_name)
|
||||
sess = onnxruntime.InferenceSession(int8_model_name, providers=onnxruntime.get_available_providers())
|
||||
count, duration, per_iter_cost = perf_run(sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds)
|
||||
avg_int8 = top_n_avg(per_iter_cost, top_n)
|
||||
print('perf_int8 (with {} threads) {}: run for {} iterations, top {} avg {:.3f} ms'.format(num_threads, int8_model_name, count, top_n, avg_int8))
|
||||
|
|
|
|||
|
|
@ -357,6 +357,7 @@ class InferenceSession(Session):
|
|||
provider_options,
|
||||
available_providers)
|
||||
if providers == [] and len(available_providers) > 1:
|
||||
self.disable_fallback()
|
||||
raise ValueError("This ORT build has {} enabled. ".format(available_providers) +
|
||||
"Since ORT 1.9, you are required to explicitly set " +
|
||||
"the providers parameter when instantiating InferenceSession. For example, "
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def run_model(model_path,
|
|||
sess_options.enable_profiling = True
|
||||
sess_options.profile_file_prefix = os.path.basename(model_path)
|
||||
|
||||
sess = onnxrt.InferenceSession(model_path, sess_options)
|
||||
sess = onnxrt.InferenceSession(model_path, sess_options=sess_options, providers=onnxrt.get_available_providers())
|
||||
meta = sess.get_modelmeta()
|
||||
|
||||
if not feeds:
|
||||
|
|
|
|||
Loading…
Reference in a new issue