mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
[eager mode] fix build and support customize shared provider entry point (#8680)
* fix build break * support customize the name of shared provide lib's entry point * fix non training build * check error code * check return code
This commit is contained in:
parent
f661c18654
commit
de2a53e46d
14 changed files with 48 additions and 22 deletions
|
|
@ -7,9 +7,12 @@
|
|||
#include "core/graph/model.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/session/ort_env.h"
|
||||
#include "core/graph/constants.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
#define ORT_EAGER_ONNX_OPSET_VERSION 14
|
||||
|
||||
common::Status ORTInvoker::Invoke(const std::string& op_name,
|
||||
//optional inputs / outputs?
|
||||
const std::vector<OrtValue>& inputs,
|
||||
|
|
@ -17,13 +20,15 @@ common::Status ORTInvoker::Invoke(const std::string& op_name,
|
|||
const NodeAttributes* attributes,
|
||||
const std::string& domain,
|
||||
const int version) {
|
||||
std::unordered_map<std::string, int> domain_version_map = {{kOnnxDomain, ORT_EAGER_ONNX_OPSET_VERSION},
|
||||
{kMSDomain, 1}};
|
||||
//create a graph
|
||||
Model model("test",
|
||||
false,
|
||||
ModelMetaData(),
|
||||
ORT_TSTR(""),
|
||||
custom_op_registries_,
|
||||
{},
|
||||
domain_version_map,
|
||||
{},
|
||||
logger_);
|
||||
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ const OrtDevice::DeviceType OrtDevice::GPU;
|
|||
namespace onnxruntime {
|
||||
|
||||
constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path";
|
||||
constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point";
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
|
@ -299,7 +301,8 @@ static inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime
|
|||
|
||||
static std::unique_ptr<onnxruntime::IExecutionProvider> LoadExecutionProvider(
|
||||
const std::string& ep_shared_lib_path,
|
||||
const ProviderOptions& provider_options = {}) {
|
||||
const ProviderOptions& provider_options = {},
|
||||
const std::string& entry_symbol_name = "GetProvider") {
|
||||
void* handle;
|
||||
auto error = Env::Default().LoadDynamicLibrary(ep_shared_lib_path, false, &handle);
|
||||
if (!error.IsOK()) {
|
||||
|
|
@ -307,7 +310,7 @@ static std::unique_ptr<onnxruntime::IExecutionProvider> LoadExecutionProvider(
|
|||
}
|
||||
|
||||
Provider* (*PGetProvider)();
|
||||
Env::Default().GetSymbolFromLibrary(handle, "GetProvider", (void**)&PGetProvider);
|
||||
OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(handle, entry_symbol_name, (void**)&PGetProvider));
|
||||
|
||||
Provider* provider = PGetProvider();
|
||||
std::shared_ptr<IExecutionProviderFactory> ep_factory = provider->CreateExecutionProviderFactory(&provider_options);
|
||||
|
|
@ -681,11 +684,16 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
|
|||
// this is an EP with dynamic loading
|
||||
// construct the provider option
|
||||
ProviderOptions provider_options;
|
||||
std::string entry_symbol = kDefaultExecutionProviderEntry;
|
||||
for (auto option : it->second) {
|
||||
if (option.first != kExecutionProviderSharedLibraryPath)
|
||||
if (option.first == kExecutionProviderSharedLibraryEntry){
|
||||
entry_symbol = option.second;
|
||||
}
|
||||
else if (option.first != kExecutionProviderSharedLibraryPath){
|
||||
provider_options.insert(option);
|
||||
}
|
||||
}
|
||||
auto p_ep = LoadExecutionProvider(shared_lib_path_it->second, provider_options);
|
||||
auto p_ep = LoadExecutionProvider(shared_lib_path_it->second, provider_options, entry_symbol);
|
||||
ORT_THROW_IF_ERROR(sess->RegisterExecutionProvider(
|
||||
std::move(p_ep)));
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -454,4 +454,5 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(
|
|||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu();
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CoreML(uint32_t flags);
|
||||
|
||||
constexpr const char* kDefaultExecutionProviderEntry = "GetProvider";
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1092,6 +1092,7 @@ class TestInferenceSession(unittest.TestCase):
|
|||
sess = C.InferenceSession(session_options, custom_op_model, True, True)
|
||||
sess.initialize_session(['my_ep'],
|
||||
[{'shared_lib_path': shared_library,
|
||||
'provider_factory_entry_point' : 'ProviderEntryPoint',
|
||||
'device_id':'1', 'some_config':'val'}],
|
||||
set())
|
||||
print("Create session with customize execution provider successfully!")
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
_GetProvider
|
||||
_ProviderEntryPoint
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ struct MyEP_Provider : Provider {
|
|||
|
||||
extern "C" {
|
||||
|
||||
ORT_API(onnxruntime::Provider*, GetProvider) {
|
||||
ORT_API(onnxruntime::Provider*, ProviderEntryPoint) {
|
||||
return &onnxruntime::g_provider;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ extern "C" {
|
|||
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MyEP, _In_ OrtSessionOptions* options, int device_id);
|
||||
|
||||
ORT_API(onnxruntime::Provider*, GetProvider);
|
||||
ORT_API(onnxruntime::Provider*, ProviderEntryPoint);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
EXPORTS
|
||||
GetProvider
|
||||
ProviderEntryPoint
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
#_init and _fini should be local
|
||||
VERS_1.0 {
|
||||
global:
|
||||
GetProvider;
|
||||
ProviderEntryPoint;
|
||||
|
||||
# Hide everything else.
|
||||
local:
|
||||
|
|
|
|||
|
|
@ -273,7 +273,7 @@ at::Tensor& zero_(at::Tensor& self){
|
|||
"ZeroGradient", {
|
||||
std::move(ort_in_self),
|
||||
std::move(flag_val)
|
||||
}, ort_out, nullptr, onnxruntime::kMSDomain);
|
||||
}, ort_out, nullptr, onnxruntime::kMSDomain, 1);
|
||||
|
||||
if (!status.IsOK())
|
||||
throw std::runtime_error(
|
||||
|
|
|
|||
|
|
@ -49,8 +49,10 @@ ORTBackendsManager::ORTBackendsManager(const onnxruntime::logging::Logger& logge
|
|||
}
|
||||
}
|
||||
|
||||
void ORTBackendsManager::RegisterProviderLib(const std::string& provider_type, const std::string& lib_path){
|
||||
additional_provider_libs_.insert({provider_type, lib_path});
|
||||
void ORTBackendsManager::RegisterProviderLib(const std::string& provider_type,
|
||||
const std::string& lib_path,
|
||||
const std::string& entry_point){
|
||||
additional_provider_libs_.insert({provider_type, {lib_path, entry_point}});
|
||||
}
|
||||
|
||||
onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const std::string& provider_type,
|
||||
|
|
@ -73,14 +75,16 @@ onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const st
|
|||
}
|
||||
else{
|
||||
auto shared_lib_path_it = additional_provider_libs_.find(provider_type);
|
||||
if (shared_lib_path_it == provider_options.end()){
|
||||
if (shared_lib_path_it == additional_provider_libs_.end()){
|
||||
return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME,
|
||||
common::StatusCode::INVALID_ARGUMENT,
|
||||
"Execution provider: " + provider_type + " is not supported.");
|
||||
}
|
||||
|
||||
void* handle;
|
||||
auto error = Env::Default().LoadDynamicLibrary(shared_lib_path_it->second, false, &handle);
|
||||
auto lib_path = shared_lib_path_it->second.first;
|
||||
auto entry_point = shared_lib_path_it->second.second;
|
||||
auto error = Env::Default().LoadDynamicLibrary(lib_path, false, &handle);
|
||||
if (!error.IsOK()) {
|
||||
return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME,
|
||||
common::StatusCode::INVALID_ARGUMENT,
|
||||
|
|
@ -89,7 +93,7 @@ onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const st
|
|||
}
|
||||
|
||||
Provider* (*PGetProvider)();
|
||||
Env::Default().GetSymbolFromLibrary(handle, "GetProvider", (void**)&PGetProvider);
|
||||
ORT_RETURN_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle, entry_point, (void**)&PGetProvider));
|
||||
|
||||
Provider* provider = PGetProvider();
|
||||
std::shared_ptr<IExecutionProviderFactory> ep_factory = provider->CreateExecutionProviderFactory(&provider_options);
|
||||
|
|
|
|||
|
|
@ -19,7 +19,9 @@ class ORTBackendsManager {
|
|||
public:
|
||||
ORTBackendsManager(const onnxruntime::logging::Logger& logger);
|
||||
|
||||
void RegisterProviderLib(const std::string& provider_type, const std::string& lib_path);
|
||||
void RegisterProviderLib(const std::string& provider_type,
|
||||
const std::string& lib_path,
|
||||
const std::string& entry_point);
|
||||
|
||||
onnxruntime::Status set_device(size_t device_index, const std::string& provider_type,
|
||||
const onnxruntime::ProviderOptions& provider_options);
|
||||
|
|
@ -32,7 +34,7 @@ private:
|
|||
//custom op schema registry
|
||||
//TODO: we might want to support load custom op schema on the fly
|
||||
onnxruntime::IOnnxRuntimeOpSchemaRegistryList custom_op_schema_ = {};
|
||||
std::unordered_map<std::string, std::string> additional_provider_libs_ = {};
|
||||
std::unordered_map<std::string, std::pair<std::string, std::string> > additional_provider_libs_ = {};
|
||||
};
|
||||
|
||||
ORTBackendsManager& GetORTBackendsManager();
|
||||
|
|
|
|||
|
|
@ -51,9 +51,14 @@ void addObjectMethodsForEager(py::module& m){
|
|||
return ORTTensor_FromDLPack(dlpack_tensor);
|
||||
});
|
||||
|
||||
m.def("_register_provider_lib", [](const std::string& name, const std::string& provider_shared_lib_path ) {
|
||||
torch_ort::eager::GetORTBackendsManager().RegisterProviderLib(name, provider_shared_lib_path);
|
||||
});
|
||||
m.def("_register_provider_lib", [](const std::string& name,
|
||||
const std::string& provider_shared_lib_path,
|
||||
const std::string& provider_factory_entry) {
|
||||
torch_ort::eager::GetORTBackendsManager().RegisterProviderLib(name, provider_shared_lib_path, provider_factory_entry);
|
||||
},
|
||||
py::arg("name"),
|
||||
py::arg("provider_shared_lib_path"),
|
||||
py::arg("provider_factory_entry") = kDefaultExecutionProviderEntry);
|
||||
|
||||
m.def("set_device", [](size_t device_index,
|
||||
const std::string& provider_type,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class OrtEPTests(unittest.TestCase):
|
|||
def test_import_custom_eps(self):
|
||||
torch_ort.set_device(0, 'CPUExecutionProvider', {})
|
||||
|
||||
torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path())
|
||||
torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), 'ProviderEntryPoint')
|
||||
torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
|
||||
ort_device = torch_ort.device(1)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue