From de2a53e46dcb0192c6961e3a91ffe7c3bbc8d34b Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Wed, 11 Aug 2021 15:10:35 -0700 Subject: [PATCH] [eager mode] fix build and support customize shared provider entry point (#8680) * fix build break * support customize the name of shared provide lib's entry point * fix non training build * check error code * check return code --- onnxruntime/core/eager/ort_kernel_invoker.cc | 7 ++++++- onnxruntime/python/onnxruntime_pybind_state.cc | 16 ++++++++++++---- .../python/onnxruntime_pybind_state_common.h | 1 + .../test/python/onnxruntime_test_python.py | 1 + .../exported_symbols.lst | 2 +- .../my_ep_factory.cc | 2 +- .../my_ep_factory.h | 2 +- .../symbols.def | 2 +- .../version_script.lds | 2 +- orttraining/orttraining/eager/ort_aten.cpp | 2 +- orttraining/orttraining/eager/ort_backends.cpp | 14 +++++++++----- orttraining/orttraining/eager/ort_backends.h | 6 ++++-- orttraining/orttraining/eager/ort_eager.cpp | 11 ++++++++--- .../orttraining/eager/test/ort_eps_test.py | 2 +- 14 files changed, 48 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/eager/ort_kernel_invoker.cc b/onnxruntime/core/eager/ort_kernel_invoker.cc index f64dd0eaa9..ce5bf6dea9 100644 --- a/onnxruntime/core/eager/ort_kernel_invoker.cc +++ b/onnxruntime/core/eager/ort_kernel_invoker.cc @@ -7,9 +7,12 @@ #include "core/graph/model.h" #include "core/framework/op_kernel.h" #include "core/session/ort_env.h" +#include "core/graph/constants.h" namespace onnxruntime { +#define ORT_EAGER_ONNX_OPSET_VERSION 14 + common::Status ORTInvoker::Invoke(const std::string& op_name, //optional inputs / outputs? const std::vector& inputs, @@ -17,13 +20,15 @@ common::Status ORTInvoker::Invoke(const std::string& op_name, const NodeAttributes* attributes, const std::string& domain, const int version) { + std::unordered_map domain_version_map = {{kOnnxDomain, ORT_EAGER_ONNX_OPSET_VERSION}, + {kMSDomain, 1}}; //create a graph Model model("test", false, ModelMetaData(), ORT_TSTR(""), custom_op_registries_, - {}, + domain_version_map, {}, logger_); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index c7d08bd478..e049e94c04 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -46,6 +46,8 @@ const OrtDevice::DeviceType OrtDevice::GPU; namespace onnxruntime { constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path"; +constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point"; + } // namespace onnxruntime #if defined(_MSC_VER) @@ -299,7 +301,8 @@ static inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime static std::unique_ptr LoadExecutionProvider( const std::string& ep_shared_lib_path, - const ProviderOptions& provider_options = {}) { + const ProviderOptions& provider_options = {}, + const std::string& entry_symbol_name = "GetProvider") { void* handle; auto error = Env::Default().LoadDynamicLibrary(ep_shared_lib_path, false, &handle); if (!error.IsOK()) { @@ -307,7 +310,7 @@ static std::unique_ptr LoadExecutionProvider( } Provider* (*PGetProvider)(); - Env::Default().GetSymbolFromLibrary(handle, "GetProvider", (void**)&PGetProvider); + OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(handle, entry_symbol_name, (void**)&PGetProvider)); Provider* provider = PGetProvider(); std::shared_ptr ep_factory = provider->CreateExecutionProviderFactory(&provider_options); @@ -681,11 +684,16 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector // this is an EP with dynamic loading // construct the provider option ProviderOptions provider_options; + std::string entry_symbol = kDefaultExecutionProviderEntry; for (auto option : it->second) { - if (option.first != kExecutionProviderSharedLibraryPath) + if (option.first == kExecutionProviderSharedLibraryEntry){ + entry_symbol = option.second; + } + else if (option.first != kExecutionProviderSharedLibraryPath){ provider_options.insert(option); + } } - auto p_ep = LoadExecutionProvider(shared_lib_path_it->second, provider_options); + auto p_ep = LoadExecutionProvider(shared_lib_path_it->second, provider_options, entry_symbol); ORT_THROW_IF_ERROR(sess->RegisterExecutionProvider( std::move(p_ep))); continue; diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 47645642ed..8c30d61f9b 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -454,4 +454,5 @@ std::shared_ptr CreateExecutionProviderFactory_Nnapi( std::shared_ptr CreateExecutionProviderFactory_Rknpu(); std::shared_ptr CreateExecutionProviderFactory_CoreML(uint32_t flags); +constexpr const char* kDefaultExecutionProviderEntry = "GetProvider"; } // namespace onnxruntime diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 52e34e08fd..045cb27c88 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1092,6 +1092,7 @@ class TestInferenceSession(unittest.TestCase): sess = C.InferenceSession(session_options, custom_op_model, True, True) sess.initialize_session(['my_ep'], [{'shared_lib_path': shared_library, + 'provider_factory_entry_point' : 'ProviderEntryPoint', 'device_id':'1', 'some_config':'val'}], set()) print("Create session with customize execution provider successfully!") diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst b/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst index f4c4141259..d1c294b006 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst +++ b/onnxruntime/test/testdata/custom_execution_provider_library/exported_symbols.lst @@ -1 +1 @@ -_GetProvider +_ProviderEntryPoint diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc index aa1a9b6f4a..f48b1340e1 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.cc @@ -64,7 +64,7 @@ struct MyEP_Provider : Provider { extern "C" { -ORT_API(onnxruntime::Provider*, GetProvider) { +ORT_API(onnxruntime::Provider*, ProviderEntryPoint) { return &onnxruntime::g_provider; } diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h index b342960935..3c19549311 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h +++ b/onnxruntime/test/testdata/custom_execution_provider_library/my_ep_factory.h @@ -10,7 +10,7 @@ extern "C" { ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MyEP, _In_ OrtSessionOptions* options, int device_id); -ORT_API(onnxruntime::Provider*, GetProvider); +ORT_API(onnxruntime::Provider*, ProviderEntryPoint); #ifdef __cplusplus } diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def b/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def index 4ec2f7914c..44d2bdf70f 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def +++ b/onnxruntime/test/testdata/custom_execution_provider_library/symbols.def @@ -1,2 +1,2 @@ EXPORTS - GetProvider + ProviderEntryPoint diff --git a/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds b/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds index b298a6d003..1ddb548065 100644 --- a/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds +++ b/onnxruntime/test/testdata/custom_execution_provider_library/version_script.lds @@ -1,7 +1,7 @@ #_init and _fini should be local VERS_1.0 { global: - GetProvider; + ProviderEntryPoint; # Hide everything else. local: diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index f10b9c027d..0caa145a35 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -273,7 +273,7 @@ at::Tensor& zero_(at::Tensor& self){ "ZeroGradient", { std::move(ort_in_self), std::move(flag_val) - }, ort_out, nullptr, onnxruntime::kMSDomain); + }, ort_out, nullptr, onnxruntime::kMSDomain, 1); if (!status.IsOK()) throw std::runtime_error( diff --git a/orttraining/orttraining/eager/ort_backends.cpp b/orttraining/orttraining/eager/ort_backends.cpp index f13c326972..733e1b43bb 100644 --- a/orttraining/orttraining/eager/ort_backends.cpp +++ b/orttraining/orttraining/eager/ort_backends.cpp @@ -49,8 +49,10 @@ ORTBackendsManager::ORTBackendsManager(const onnxruntime::logging::Logger& logge } } -void ORTBackendsManager::RegisterProviderLib(const std::string& provider_type, const std::string& lib_path){ - additional_provider_libs_.insert({provider_type, lib_path}); +void ORTBackendsManager::RegisterProviderLib(const std::string& provider_type, + const std::string& lib_path, + const std::string& entry_point){ + additional_provider_libs_.insert({provider_type, {lib_path, entry_point}}); } onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const std::string& provider_type, @@ -73,14 +75,16 @@ onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const st } else{ auto shared_lib_path_it = additional_provider_libs_.find(provider_type); - if (shared_lib_path_it == provider_options.end()){ + if (shared_lib_path_it == additional_provider_libs_.end()){ return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::INVALID_ARGUMENT, "Execution provider: " + provider_type + " is not supported."); } void* handle; - auto error = Env::Default().LoadDynamicLibrary(shared_lib_path_it->second, false, &handle); + auto lib_path = shared_lib_path_it->second.first; + auto entry_point = shared_lib_path_it->second.second; + auto error = Env::Default().LoadDynamicLibrary(lib_path, false, &handle); if (!error.IsOK()) { return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::INVALID_ARGUMENT, @@ -89,7 +93,7 @@ onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const st } Provider* (*PGetProvider)(); - Env::Default().GetSymbolFromLibrary(handle, "GetProvider", (void**)&PGetProvider); + ORT_RETURN_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle, entry_point, (void**)&PGetProvider)); Provider* provider = PGetProvider(); std::shared_ptr ep_factory = provider->CreateExecutionProviderFactory(&provider_options); diff --git a/orttraining/orttraining/eager/ort_backends.h b/orttraining/orttraining/eager/ort_backends.h index 99da405117..1a11458acb 100644 --- a/orttraining/orttraining/eager/ort_backends.h +++ b/orttraining/orttraining/eager/ort_backends.h @@ -19,7 +19,9 @@ class ORTBackendsManager { public: ORTBackendsManager(const onnxruntime::logging::Logger& logger); - void RegisterProviderLib(const std::string& provider_type, const std::string& lib_path); + void RegisterProviderLib(const std::string& provider_type, + const std::string& lib_path, + const std::string& entry_point); onnxruntime::Status set_device(size_t device_index, const std::string& provider_type, const onnxruntime::ProviderOptions& provider_options); @@ -32,7 +34,7 @@ private: //custom op schema registry //TODO: we might want to support load custom op schema on the fly onnxruntime::IOnnxRuntimeOpSchemaRegistryList custom_op_schema_ = {}; - std::unordered_map additional_provider_libs_ = {}; + std::unordered_map > additional_provider_libs_ = {}; }; ORTBackendsManager& GetORTBackendsManager(); diff --git a/orttraining/orttraining/eager/ort_eager.cpp b/orttraining/orttraining/eager/ort_eager.cpp index 986607648f..d4dbc18f5b 100644 --- a/orttraining/orttraining/eager/ort_eager.cpp +++ b/orttraining/orttraining/eager/ort_eager.cpp @@ -51,9 +51,14 @@ void addObjectMethodsForEager(py::module& m){ return ORTTensor_FromDLPack(dlpack_tensor); }); - m.def("_register_provider_lib", [](const std::string& name, const std::string& provider_shared_lib_path ) { - torch_ort::eager::GetORTBackendsManager().RegisterProviderLib(name, provider_shared_lib_path); - }); + m.def("_register_provider_lib", [](const std::string& name, + const std::string& provider_shared_lib_path, + const std::string& provider_factory_entry) { + torch_ort::eager::GetORTBackendsManager().RegisterProviderLib(name, provider_shared_lib_path, provider_factory_entry); + }, + py::arg("name"), + py::arg("provider_shared_lib_path"), + py::arg("provider_factory_entry") = kDefaultExecutionProviderEntry); m.def("set_device", [](size_t device_index, const std::string& provider_type, diff --git a/orttraining/orttraining/eager/test/ort_eps_test.py b/orttraining/orttraining/eager/test/ort_eps_test.py index f77872d531..f7c63f8587 100644 --- a/orttraining/orttraining/eager/test/ort_eps_test.py +++ b/orttraining/orttraining/eager/test/ort_eps_test.py @@ -13,7 +13,7 @@ class OrtEPTests(unittest.TestCase): def test_import_custom_eps(self): torch_ort.set_device(0, 'CPUExecutionProvider', {}) - torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path()) + torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), 'ProviderEntryPoint') torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'}) ort_device = torch_ort.device(1)