From 4dc0ddf606de526ef62afaeb8d285d4786ae3f1b Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Tue, 31 Aug 2021 20:51:22 -0700 Subject: [PATCH] support register external ep lib information (#8897) * support register external ep lib inforation; make eager mode share the same ep pools with training workloads * fix inference code * fix build break * fix the message --- .../core/eager/ort_kernel_invoker.h | 4 +- .../python/onnxruntime_pybind_module.cc | 8 ++ .../python/onnxruntime_pybind_state.cc | 21 ++-- .../test/python/onnxruntime_test_python.py | 1 - .../exported_symbols.lst | 2 +- .../my_ep_factory.cc | 2 +- .../my_ep_factory.h | 2 +- .../symbols.def | 2 +- .../version_script.lds | 2 +- .../orttraining/eager/ort_backends.cpp | 49 ++------ orttraining/orttraining/eager/ort_backends.h | 5 - orttraining/orttraining/eager/ort_eager.cpp | 9 -- .../orttraining/eager/test/ort_eps_test.py | 2 +- .../python/orttraining_python_module.cc | 118 ++++++++++++++---- .../python/onnxruntime_test_register_ep.py | 29 +++++ 15 files changed, 156 insertions(+), 100 deletions(-) create mode 100644 orttraining/orttraining/test/python/onnxruntime_test_register_ep.py diff --git a/include/onnxruntime/core/eager/ort_kernel_invoker.h b/include/onnxruntime/core/eager/ort_kernel_invoker.h index 500ffdfb9a..cbe7a43acc 100644 --- a/include/onnxruntime/core/eager/ort_kernel_invoker.h +++ b/include/onnxruntime/core/eager/ort_kernel_invoker.h @@ -22,7 +22,7 @@ namespace onnxruntime { class ORTInvoker { public: - ORTInvoker(std::unique_ptr execution_provider, + ORTInvoker(std::shared_ptr execution_provider, const logging::Logger& logger, const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) : execution_provider_(std::move(execution_provider)), logger_(logger), custom_op_registries_(custom_op_registries) { @@ -44,7 +44,7 @@ class ORTInvoker { const int version = -1); private: - std::unique_ptr execution_provider_; + std::shared_ptr execution_provider_; const logging::Logger& logger_; // custom ops for current execution provider // we need the op schema to resolve the output type during invoke diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index 98017d981a..378117b096 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License. #include +#include +#include "core/providers/get_execution_providers.h" namespace onnxruntime { namespace python { @@ -11,6 +13,12 @@ void CreateInferencePybindStateModule(py::module& m); PYBIND11_MODULE(onnxruntime_pybind11_state, m) { CreateInferencePybindStateModule(m); + // move it out of shared method since training build has a little different behavior. + m.def( + "get_available_providers", []() -> const std::vector& { return GetAvailableExecutionProviderNames(); }, + "Return list of available Execution Providers in this installed version of Onnxruntime. " + "The order of elements represents the default priority order of Execution Providers " + "from highest to lowest."); } } } \ No newline at end of file diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 750da8fbd2..f6dedb3ece 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -15,7 +15,6 @@ #include "core/framework/arena_extend_strategy.h" #include "core/framework/data_transfer_utils.h" #include "core/framework/data_types_internal.h" -#include "core/providers/get_execution_providers.h" #include "core/framework/provider_options_utils.h" #include "core/framework/random_seed.h" #include "core/framework/sparse_tensor.h" @@ -23,6 +22,7 @@ #include "core/framework/TensorSeq.h" #include "core/graph/graph_viewer.h" #include "core/platform/env.h" +#include "core/providers/get_execution_providers.h" #include "core/session/IOBinding.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -338,12 +338,12 @@ const ROCMExecutionProviderInfo GetROCMExecutionProviderInfo(const ProviderOptio #endif std::unique_ptr CreateExecutionProviderInstance( - InferenceSession* sess, + const SessionOptions& session_options, const std::string& type, const ProviderOptionsMap& provider_options_map){ if (type == kCpuExecutionProvider) { return onnxruntime::CreateExecutionProviderFactory_CPU( - sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider(); + session_options.enable_cpu_mem_arena)->CreateProvider(); } else if (type == kTensorrtExecutionProvider) { #ifdef USE_TENSORRT std::string calibration_table, cache_path, lib_path; @@ -530,7 +530,7 @@ std::unique_ptr CreateExecutionProviderInstance( } else if (type == kDnnlExecutionProvider) { #ifdef USE_DNNL return onnxruntime::CreateExecutionProviderFactory_Dnnl( - sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider(); + session_options.enable_cpu_mem_arena)->CreateProvider(); #endif } else if (type == kOpenVINOExecutionProvider) { #ifdef USE_OPENVINO @@ -628,12 +628,12 @@ std::unique_ptr CreateExecutionProviderInstance( } else if (type == kAclExecutionProvider) { #ifdef USE_ACL return onnxruntime::CreateExecutionProviderFactory_ACL( - sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider(); + session_options.enable_cpu_mem_arena)->CreateProvider(); #endif } else if (type == kArmNNExecutionProvider) { #ifdef USE_ARMNN return onnxruntime::CreateExecutionProviderFactory_ArmNN( - sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider(); + session_options.enable_cpu_mem_arena)->CreateProvider(); #endif } else if (type == kDmlExecutionProvider) { #ifdef USE_DML @@ -655,7 +655,7 @@ std::unique_ptr CreateExecutionProviderInstance( #if !defined(__ANDROID__) LOGS_DEFAULT(WARNING) << "NNAPI execution provider can only be used to generate ORT format model in this build."; #endif - const auto partitioning_stop_ops_list = sess->GetSessionOptions().config_options.GetConfigEntry( + const auto partitioning_stop_ops_list = session_options.config_options.GetConfigEntry( kOrtSessionOptionsConfigNnapiEpPartitioningStopOps); return onnxruntime::CreateExecutionProviderFactory_Nnapi(0, partitioning_stop_ops_list)->CreateProvider(); #endif @@ -705,7 +705,7 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector ORT_UNUSED_PARAMETER(provider_options_map); for (const std::string& type : provider_types) { - auto ep = CreateExecutionProviderInstance(sess, type, provider_options_map); + auto ep = CreateExecutionProviderInstance(sess->GetSessionOptions(), type, provider_options_map); if (ep) OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(ep))); } @@ -833,11 +833,6 @@ void addGlobalMethods(py::module& m, Environment& env) { "Return list of Execution Providers that this version of Onnxruntime can support. " "The order of elements represents the default priority order of Execution Providers " "from highest to lowest."); - m.def( - "get_available_providers", []() -> const std::vector& { return GetAvailableExecutionProviderNames(); }, - "Return list of available Execution Providers available in this installed version of Onnxruntime. " - "The order of elements represents the default priority order of Execution Providers " - "from highest to lowest."); m.def( "enable_telemetry_events", []() -> void { platform_env.GetTelemetryProvider().EnableTelemetryEvents(); }, "Enables platform-specific telemetry collection where applicable."); diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 045cb27c88..52e34e08fd 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1092,7 +1092,6 @@ class TestInferenceSession(unittest.TestCase): sess = C.InferenceSession(session_options, custom_op_model, True, True) sess.initialize_session(['my_ep'], [{'shared_lib_path': shared_library, - 'provider_factory_entry_point' : 'ProviderEntryPoint', 'device_id':'1', 'some_config':'val'}], set()) print("Create session with customize execution provider successfully!") diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst b/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst index d1c294b006..f4c4141259 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst +++ b/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst @@ -1 +1 @@ -_ProviderEntryPoint +_GetProvider diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc index f48b1340e1..aa1a9b6f4a 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc @@ -64,7 +64,7 @@ struct MyEP_Provider : Provider { extern "C" { -ORT_API(onnxruntime::Provider*, ProviderEntryPoint) { +ORT_API(onnxruntime::Provider*, GetProvider) { return &onnxruntime::g_provider; } diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h index 3c19549311..b342960935 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h @@ -10,7 +10,7 @@ extern "C" { ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MyEP, _In_ OrtSessionOptions* options, int device_id); -ORT_API(onnxruntime::Provider*, ProviderEntryPoint); +ORT_API(onnxruntime::Provider*, GetProvider); #ifdef __cplusplus } diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def b/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def index 44d2bdf70f..4ec2f7914c 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def +++ b/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def @@ -1,2 +1,2 @@ EXPORTS - ProviderEntryPoint + GetProvider diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds b/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds index 1ddb548065..b298a6d003 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds +++ b/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds @@ -1,7 +1,7 @@ #_init and _fini should be local VERS_1.0 { global: - ProviderEntryPoint; + GetProvider; # Hide everything else. local: diff --git a/orttraining/orttraining/eager/ort_backends.cpp b/orttraining/orttraining/eager/ort_backends.cpp index 944432ccea..fc32f1d06a 100644 --- a/orttraining/orttraining/eager/ort_backends.cpp +++ b/orttraining/orttraining/eager/ort_backends.cpp @@ -15,6 +15,9 @@ namespace onnxruntime{ namespace python{ Environment& GetTrainingORTEnv(); + std::shared_ptr GetOrCreateExecutionProvider(const std::string& provider_type, + const ProviderOptionsMap& provider_options_map, + const SessionOptions& session_options); } } @@ -41,53 +44,15 @@ ORTBackendsManager::ORTBackendsManager(const onnxruntime::logging::Logger& logge } } -void ORTBackendsManager::RegisterProviderLib(const std::string& provider_type, - const std::string& lib_path, - const std::string& entry_point){ - additional_provider_libs_.insert({provider_type, {lib_path, entry_point}}); -} - onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const std::string& provider_type, const ProviderOptions& provider_options){ - // query avalible device - auto& available_providers = GetAvailableExecutionProviderNames(); - std::unique_ptr provider_p; - if (std::find(available_providers.begin(), available_providers.end(), provider_type) != available_providers.end()){ - if (provider_type == kCpuExecutionProvider){ - provider_p = onnxruntime::CreateExecutionProviderFactory_CPU(0)->CreateProvider(); - } - } - else{ - auto shared_lib_path_it = additional_provider_libs_.find(provider_type); - if (shared_lib_path_it == additional_provider_libs_.end()){ - return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME, - common::StatusCode::INVALID_ARGUMENT, - "Execution provider: " + provider_type + " is not supported."); - } - - void* handle; - auto lib_path = shared_lib_path_it->second.first; - auto entry_point = shared_lib_path_it->second.second; - auto error = Env::Default().LoadDynamicLibrary(lib_path, false, &handle); - if (!error.IsOK()) { - return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME, - common::StatusCode::INVALID_ARGUMENT, - "Load shared execution provider: " + provider_type + " failed: " - + error.ErrorMessage()); - } - - Provider* (*PGetProvider)(); - ORT_RETURN_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle, entry_point, (void**)&PGetProvider)); - - Provider* provider = PGetProvider(); - std::shared_ptr ep_factory = provider->CreateExecutionProviderFactory(&provider_options); - provider_p = ep_factory->CreateProvider(); - } - + auto ep = onnxruntime::python::GetOrCreateExecutionProvider(provider_type, + ProviderOptionsMap{{provider_type, provider_options}}, + SessionOptions{}); auto invoker = std::make_unique( - std::move(provider_p), + std::move(ep), logger_, custom_op_schema_); diff --git a/orttraining/orttraining/eager/ort_backends.h b/orttraining/orttraining/eager/ort_backends.h index 5cd77f272d..a5090b1373 100644 --- a/orttraining/orttraining/eager/ort_backends.h +++ b/orttraining/orttraining/eager/ort_backends.h @@ -19,10 +19,6 @@ class ORTBackendsManager { public: ORTBackendsManager(const onnxruntime::logging::Logger& logger); - void RegisterProviderLib(const std::string& provider_type, - const std::string& lib_path, - const std::string& entry_point); - onnxruntime::Status set_device(size_t device_index, const std::string& provider_type, const onnxruntime::ProviderOptions& provider_options); @@ -34,7 +30,6 @@ private: //custom op schema registry //TODO: we might want to support load custom op schema on the fly onnxruntime::IOnnxRuntimeOpSchemaRegistryList custom_op_schema_ = {}; - std::unordered_map > additional_provider_libs_ = {}; }; ORTBackendsManager& GetORTBackendsManager(); diff --git a/orttraining/orttraining/eager/ort_eager.cpp b/orttraining/orttraining/eager/ort_eager.cpp index 36ebf85ae9..556e8751a8 100644 --- a/orttraining/orttraining/eager/ort_eager.cpp +++ b/orttraining/orttraining/eager/ort_eager.cpp @@ -52,15 +52,6 @@ void addObjectMethodsForEager(py::module& m){ return ORTTensor_FromDLPack(dlpack_tensor); }); - m.def("_register_provider_lib", [](const std::string& name, - const std::string& provider_shared_lib_path, - const std::string& provider_factory_entry) { - torch_ort::eager::GetORTBackendsManager().RegisterProviderLib(name, provider_shared_lib_path, provider_factory_entry); - }, - py::arg("name"), - py::arg("provider_shared_lib_path"), - py::arg("provider_factory_entry") = kDefaultExecutionProviderEntry); - m.def("set_device", [](size_t device_index, const std::string& provider_type, const std::unordered_map& arguments){ diff --git a/orttraining/orttraining/eager/test/ort_eps_test.py b/orttraining/orttraining/eager/test/ort_eps_test.py index f7c63f8587..f1cc141fef 100644 --- a/orttraining/orttraining/eager/test/ort_eps_test.py +++ b/orttraining/orttraining/eager/test/ort_eps_test.py @@ -13,7 +13,7 @@ class OrtEPTests(unittest.TestCase): def test_import_custom_eps(self): torch_ort.set_device(0, 'CPUExecutionProvider', {}) - torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), 'ProviderEntryPoint') + torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), {}) torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'}) ort_device = torch_ort.device(1) diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 6f771c5d73..f8aa62bbc8 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -7,6 +7,7 @@ #include "core/common/logging/logging.h" #include "core/common/logging/severity.h" +#include "core/providers/get_execution_providers.h" #include "core/platform/env.h" #include "core/session/provider_bridge_ort.h" @@ -20,7 +21,7 @@ namespace py = pybind11; using namespace onnxruntime::logging; std::unique_ptr CreateExecutionProviderInstance( - InferenceSession* sess, + const SessionOptions& session_options, const std::string& type, const ProviderOptionsMap& provider_options_map); @@ -41,6 +42,7 @@ void addObjectMethodsForEager(py::module& m); void InitArray(); using ExecutionProviderMap = std::unordered_map >; +using ExecutionProviderLibInfoMap = std::unordered_map > ; bool GetDyanmicExecutionProviderHash( const std::string& ep_shared_lib_path, @@ -137,6 +139,8 @@ public: Severity::kWARNING, false, LoggingManager::InstanceType::Default, &SessionObjectInitializer::default_logger_id), ort_env_)); + auto& builtinEPs = GetAvailableExecutionProviderNames(); + available_training_eps_.assign(builtinEPs.begin(), builtinEPs.end()); } Environment& GetORTEnv(){ @@ -156,6 +160,20 @@ public: std::move(execution_provider)}); } + void RegisterExtExecutionProviderInfo(const std::string& provider_type, + const std::string& provider_lib_path, + const ProviderOptions& default_options){ + ext_execution_provider_info_map_.insert({provider_type, {provider_lib_path, default_options}}); + if (std::find(available_training_eps_.begin(), available_training_eps_.end(), provider_type) == available_training_eps_.end()) + available_training_eps_.push_back(provider_type); + } + + const std::vector& GetAvailableTrainingExecutionProviderTypes(){ + return available_training_eps_; + } + + ExecutionProviderLibInfoMap ext_execution_provider_info_map_; + private: std::string GetExecutionProviderMapKey(const std::string& provider_type, size_t hash){ @@ -166,6 +184,7 @@ private: std::unique_ptr ort_env_; ExecutionProviderMap execution_provider_instances_map_; + std::vector available_training_eps_; }; static std::unique_ptr ort_training_env; @@ -199,30 +218,72 @@ Environment& GetTrainingORTEnv() { return ort_training_env->GetORTEnv(); } +void ResolveExtraProviderOptions(const std::string& provider_type, + const ProviderOptionsMap& original_provider_options_map, + ProviderOptionsMap& merged_options){ + auto& training_env = GetTrainingEnv(); + auto it = training_env.ext_execution_provider_info_map_.find(provider_type); + if (it == training_env.ext_execution_provider_info_map_.end()){ + //nothing changed. + merged_options = original_provider_options_map; + }else{ + ProviderOptions options = it->second.second; + options.insert({kExecutionProviderSharedLibraryPath, it->second.first}); + auto original_map_it = original_provider_options_map.find(provider_type); + if (original_map_it != original_provider_options_map.end()){ + for (auto [k, v] : original_map_it->second){ + options.insert({k, v}); + } + } + merged_options[provider_type] = options; + } +} + +std::unique_ptr CreateTrainingEP( + const SessionOptions& session_options, + const std::string& provider_type, + const ProviderOptionsMap& provider_options_map){ + ORTTrainingPythonEnv& training_env = GetTrainingEnv(); + if (training_env.ext_execution_provider_info_map_.find(provider_type) != + training_env.ext_execution_provider_info_map_.end()){ + ProviderOptionsMap merged_options; + ResolveExtraProviderOptions(provider_type, provider_options_map, merged_options); + return CreateExecutionProviderInstance(session_options, provider_type, merged_options); + }else{ + return CreateExecutionProviderInstance(session_options, provider_type, provider_options_map); + } +} + +std::shared_ptr GetOrCreateExecutionProvider(const std::string& provider_type, + const ProviderOptionsMap& provider_options_map, + const SessionOptions& session_options){ + ORTTrainingPythonEnv& training_env = GetTrainingEnv(); + // search in environment + size_t hash; + if (GetProviderInstanceHash(provider_type, provider_options_map, hash)){ + auto cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash); + if (!cached_provider_instance){ + auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map); + if (ep){ + training_env.AddExecutionProvider(provider_type, hash, std::move(ep)); + cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash); + } + } + return cached_provider_instance; + } + else{ + // the EP doesn't support cache, register the instance to session + auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map); + return ep; + } +} + void ORTTrainingRegisterExecutionProviders(InferenceSession* sess, const std::vector& provider_types, const ProviderOptionsMap& provider_options_map) { - // search in environment - ORTTrainingPythonEnv& training_env = GetTrainingEnv(); for (auto provider_type : provider_types){ - size_t hash; - if (GetProviderInstanceHash(provider_type, provider_options_map, hash)){ - auto cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash); - if (!cached_provider_instance){ - auto ep = CreateExecutionProviderInstance(sess, provider_type, provider_options_map); - if (ep){ - training_env.AddExecutionProvider(provider_type, hash, std::move(ep)); - cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash); - } - } - if (cached_provider_instance) - OrtPybindThrowIfError(sess->RegisterExecutionProvider(cached_provider_instance)); - } - else{ - // the EP doesn't support cache, register the instance to session - auto ep = CreateExecutionProviderInstance(sess, provider_type, provider_options_map); - if (ep) - OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(ep))); - } + auto ep = GetOrCreateExecutionProvider(provider_type, provider_options_map, sess->GetSessionOptions()); + if (ep) + OrtPybindThrowIfError(sess->RegisterExecutionProvider(ep)); } } @@ -250,6 +311,19 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { #ifdef ENABLE_EAGER_MODE addObjectMethodsForEager(m); #endif + + m.def("_register_provider_lib", [](const std::string& name, + const std::string& provider_shared_lib_path, + const ProviderOptions& default_options) { + GetTrainingEnv().RegisterExtExecutionProviderInfo(name, provider_shared_lib_path, default_options); + }); + + m.def( + "get_available_providers", []() -> const std::vector& { + return GetTrainingEnv().GetAvailableTrainingExecutionProviderTypes(); }, + "Return list of available Execution Providers in this installed version of Onnxruntime. " + "The order of elements represents the default priority order of Execution Providers " + "from highest to lowest."); // clean the ort training environment when python interpreter exit // otherwise the global var will be de-constrcut after user main. diff --git a/orttraining/orttraining/test/python/onnxruntime_test_register_ep.py b/orttraining/orttraining/test/python/onnxruntime_test_register_ep.py new file mode 100644 index 0000000000..7b9d87dcc6 --- /dev/null +++ b/orttraining/orttraining/test/python/onnxruntime_test_register_ep.py @@ -0,0 +1,29 @@ +import unittest +import onnxruntime_pybind11_state as C +import os + +class EPRegistrationTests(unittest.TestCase): + def get_test_execution_provider_path(self): + return os.path.join('.', 'libtest_execution_provider.so') + + def test_register_custom_eps(self): + C._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), {'some_config':'val'}) + + assert 'TestExecutionProvider' in C.get_available_providers() + + this = os.path.dirname(__file__) + custom_op_model = os.path.join(this, "testdata", "custom_execution_provider_library", "test_model.onnx") + if not os.path.exists(custom_op_model): + raise FileNotFoundError("Unable to find '{0}'".format(custom_op_model)) + + session_options = C.get_default_session_options() + sess = C.InferenceSession(session_options, custom_op_model, True, True) + sess.initialize_session(['TestExecutionProvider'], + [{'device_id':'0'}], + set()) + print("Created session with customize execution provider successfully!") + + +if __name__ == '__main__': + unittest.main() +