From Python 3.8 and on you need to explicitly add the current directory for libraries to be loaded from it. Update onnxruntime_test_python.py with that handling. (#10129)

This commit is contained in:
Scott McKay 2021-12-28 16:10:26 +10:00 committed by GitHub
parent 3d6786c92e
commit a367f0664d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -2,17 +2,23 @@
# Licensed under the MIT License.
# -*- coding: UTF-8 -*-
import unittest
import os
import numpy as np
import gc
import numpy as np
import onnxruntime as onnxrt
import threading
import os
import platform
import sys
import threading
import unittest
from helper import get_name
from onnxruntime.capi.onnxruntime_pybind11_state import Fail
# handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed.
if platform.system() == 'Windows' and sys.version_info.major >= 3 and sys.version_info.minor >= 8:
os.add_dll_directory(os.getcwd())
class TestInferenceSession(unittest.TestCase):
def run_model(self, session_object, run_options):
@ -887,7 +893,7 @@ class TestInferenceSession(unittest.TestCase):
def testOrtValue_ghIssue9799(self):
if 'CUDAExecutionProvider' in onnxrt.get_available_providers():
session = onnxrt.InferenceSession(get_name("identity_9799.onnx"),
session = onnxrt.InferenceSession(get_name("identity_9799.onnx"),
providers=onnxrt.get_available_providers())
for seq_length in range(40, 200):
@ -1105,7 +1111,7 @@ class TestInferenceSession(unittest.TestCase):
else:
shared_library = './libtest_execution_provider.so'
if not os.path.exists(shared_library):
raise FileNotFoundError("Unable to find '{0}'".format(shared_library))
@ -1116,9 +1122,9 @@ class TestInferenceSession(unittest.TestCase):
session_options = C.get_default_session_options()
sess = C.InferenceSession(session_options, custom_op_model, True, True)
sess.initialize_session(['my_ep'],
sess.initialize_session(['my_ep'],
[{'shared_lib_path': shared_library,
'device_id':'1', 'some_config':'val'}],
'device_id':'1', 'some_config':'val'}],
set())
print("Create session with customize execution provider successfully!")