no fallback when enforcing explicit EP registration. (#9863)

* no fallback when enforcing explicit EP registration.

* add explicit ep registrations for python.
This commit is contained in:
George Wu 2021-11-25 07:26:51 -08:00 committed by GitHub
parent a3ebc5e082
commit 1e9e57df3e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 3 deletions

View file

@ -121,7 +121,7 @@ def perf_test(rnn_type, num_threads, input_dim, hidden_dim, bidirectional, layer
convert_to_scan_model(model_name, scan_model_name)
# note that symbolic shape inference is needed because model has symbolic batch dim, thus init_state is ConstantOfShape
onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(scan_model_name)), scan_model_name)
sess = onnxruntime.InferenceSession(scan_model_name)
sess = onnxruntime.InferenceSession(scan_model_name, providers=onnxruntime.get_available_providers())
count, duration, per_iter_cost = perf_run(sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds)
avg_scan = top_n_avg(per_iter_cost, top_n)
print('perf_scan (with {} threads) {}: run for {} iterations, top {} avg {:.3f} ms'.format(num_threads, scan_model_name, count, top_n, avg_scan))
@ -131,7 +131,7 @@ def perf_test(rnn_type, num_threads, input_dim, hidden_dim, bidirectional, layer
int8_model_name = os.path.splitext(model_name)[0] + '_int8.onnx'
convert_matmul_model(scan_model_name, int8_model_name)
onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(int8_model_name)), int8_model_name)
sess = onnxruntime.InferenceSession(int8_model_name)
sess = onnxruntime.InferenceSession(int8_model_name, providers=onnxruntime.get_available_providers())
count, duration, per_iter_cost = perf_run(sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds)
avg_int8 = top_n_avg(per_iter_cost, top_n)
print('perf_int8 (with {} threads) {}: run for {} iterations, top {} avg {:.3f} ms'.format(num_threads, int8_model_name, count, top_n, avg_int8))

View file

@ -357,6 +357,7 @@ class InferenceSession(Session):
provider_options,
available_providers)
if providers == [] and len(available_providers) > 1:
self.disable_fallback()
raise ValueError("This ORT build has {} enabled. ".format(available_providers) +
"Since ORT 1.9, you are required to explicitly set " +
"the providers parameter when instantiating InferenceSession. For example, "

View file

@ -71,7 +71,7 @@ def run_model(model_path,
sess_options.enable_profiling = True
sess_options.profile_file_prefix = os.path.basename(model_path)
sess = onnxrt.InferenceSession(model_path, sess_options)
sess = onnxrt.InferenceSession(model_path, sess_options=sess_options, providers=onnxrt.get_available_providers())
meta = sess.get_modelmeta()
if not feeds: