[eager mode] fix build and support customize shared provider entry point (#8680)

* fix build break

* support customize the name of shared provide lib's entry point

* fix non training build

* check error code

* check return code
This commit is contained in:
Tang, Cheng 2021-08-11 15:10:35 -07:00 committed by GitHub
parent f661c18654
commit de2a53e46d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 48 additions and 22 deletions

View file

@ -7,9 +7,12 @@
#include "core/graph/model.h"
#include "core/framework/op_kernel.h"
#include "core/session/ort_env.h"
#include "core/graph/constants.h"
namespace onnxruntime {
#define ORT_EAGER_ONNX_OPSET_VERSION 14
common::Status ORTInvoker::Invoke(const std::string& op_name,
//optional inputs / outputs?
const std::vector<OrtValue>& inputs,
@ -17,13 +20,15 @@ common::Status ORTInvoker::Invoke(const std::string& op_name,
const NodeAttributes* attributes,
const std::string& domain,
const int version) {
std::unordered_map<std::string, int> domain_version_map = {{kOnnxDomain, ORT_EAGER_ONNX_OPSET_VERSION},
{kMSDomain, 1}};
//create a graph
Model model("test",
false,
ModelMetaData(),
ORT_TSTR(""),
custom_op_registries_,
{},
domain_version_map,
{},
logger_);

View file

@ -46,6 +46,8 @@ const OrtDevice::DeviceType OrtDevice::GPU;
namespace onnxruntime {
constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path";
constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point";
} // namespace onnxruntime
#if defined(_MSC_VER)
@ -299,7 +301,8 @@ static inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime
static std::unique_ptr<onnxruntime::IExecutionProvider> LoadExecutionProvider(
const std::string& ep_shared_lib_path,
const ProviderOptions& provider_options = {}) {
const ProviderOptions& provider_options = {},
const std::string& entry_symbol_name = "GetProvider") {
void* handle;
auto error = Env::Default().LoadDynamicLibrary(ep_shared_lib_path, false, &handle);
if (!error.IsOK()) {
@ -307,7 +310,7 @@ static std::unique_ptr<onnxruntime::IExecutionProvider> LoadExecutionProvider(
}
Provider* (*PGetProvider)();
Env::Default().GetSymbolFromLibrary(handle, "GetProvider", (void**)&PGetProvider);
OrtPybindThrowIfError(Env::Default().GetSymbolFromLibrary(handle, entry_symbol_name, (void**)&PGetProvider));
Provider* provider = PGetProvider();
std::shared_ptr<IExecutionProviderFactory> ep_factory = provider->CreateExecutionProviderFactory(&provider_options);
@ -681,11 +684,16 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
// this is an EP with dynamic loading
// construct the provider option
ProviderOptions provider_options;
std::string entry_symbol = kDefaultExecutionProviderEntry;
for (auto option : it->second) {
if (option.first != kExecutionProviderSharedLibraryPath)
if (option.first == kExecutionProviderSharedLibraryEntry){
entry_symbol = option.second;
}
else if (option.first != kExecutionProviderSharedLibraryPath){
provider_options.insert(option);
}
}
auto p_ep = LoadExecutionProvider(shared_lib_path_it->second, provider_options);
auto p_ep = LoadExecutionProvider(shared_lib_path_it->second, provider_options, entry_symbol);
ORT_THROW_IF_ERROR(sess->RegisterExecutionProvider(
std::move(p_ep)));
continue;

View file

@ -454,4 +454,5 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu();
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_CoreML(uint32_t flags);
constexpr const char* kDefaultExecutionProviderEntry = "GetProvider";
} // namespace onnxruntime

View file

@ -1092,6 +1092,7 @@ class TestInferenceSession(unittest.TestCase):
sess = C.InferenceSession(session_options, custom_op_model, True, True)
sess.initialize_session(['my_ep'],
[{'shared_lib_path': shared_library,
'provider_factory_entry_point' : 'ProviderEntryPoint',
'device_id':'1', 'some_config':'val'}],
set())
print("Create session with customize execution provider successfully!")

View file

@ -1 +1 @@
_GetProvider
_ProviderEntryPoint

View file

@ -64,7 +64,7 @@ struct MyEP_Provider : Provider {
extern "C" {
ORT_API(onnxruntime::Provider*, GetProvider) {
ORT_API(onnxruntime::Provider*, ProviderEntryPoint) {
return &onnxruntime::g_provider;
}

View file

@ -10,7 +10,7 @@ extern "C" {
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MyEP, _In_ OrtSessionOptions* options, int device_id);
ORT_API(onnxruntime::Provider*, GetProvider);
ORT_API(onnxruntime::Provider*, ProviderEntryPoint);
#ifdef __cplusplus
}

View file

@ -1,2 +1,2 @@
EXPORTS
GetProvider
ProviderEntryPoint

View file

@ -1,7 +1,7 @@
#_init and _fini should be local
VERS_1.0 {
global:
GetProvider;
ProviderEntryPoint;
# Hide everything else.
local:

View file

@ -273,7 +273,7 @@ at::Tensor& zero_(at::Tensor& self){
"ZeroGradient", {
std::move(ort_in_self),
std::move(flag_val)
}, ort_out, nullptr, onnxruntime::kMSDomain);
}, ort_out, nullptr, onnxruntime::kMSDomain, 1);
if (!status.IsOK())
throw std::runtime_error(

View file

@ -49,8 +49,10 @@ ORTBackendsManager::ORTBackendsManager(const onnxruntime::logging::Logger& logge
}
}
void ORTBackendsManager::RegisterProviderLib(const std::string& provider_type, const std::string& lib_path){
additional_provider_libs_.insert({provider_type, lib_path});
void ORTBackendsManager::RegisterProviderLib(const std::string& provider_type,
const std::string& lib_path,
const std::string& entry_point){
additional_provider_libs_.insert({provider_type, {lib_path, entry_point}});
}
onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const std::string& provider_type,
@ -73,14 +75,16 @@ onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const st
}
else{
auto shared_lib_path_it = additional_provider_libs_.find(provider_type);
if (shared_lib_path_it == provider_options.end()){
if (shared_lib_path_it == additional_provider_libs_.end()){
return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME,
common::StatusCode::INVALID_ARGUMENT,
"Execution provider: " + provider_type + " is not supported.");
}
void* handle;
auto error = Env::Default().LoadDynamicLibrary(shared_lib_path_it->second, false, &handle);
auto lib_path = shared_lib_path_it->second.first;
auto entry_point = shared_lib_path_it->second.second;
auto error = Env::Default().LoadDynamicLibrary(lib_path, false, &handle);
if (!error.IsOK()) {
return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME,
common::StatusCode::INVALID_ARGUMENT,
@ -89,7 +93,7 @@ onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const st
}
Provider* (*PGetProvider)();
Env::Default().GetSymbolFromLibrary(handle, "GetProvider", (void**)&PGetProvider);
ORT_RETURN_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle, entry_point, (void**)&PGetProvider));
Provider* provider = PGetProvider();
std::shared_ptr<IExecutionProviderFactory> ep_factory = provider->CreateExecutionProviderFactory(&provider_options);

View file

@ -19,7 +19,9 @@ class ORTBackendsManager {
public:
ORTBackendsManager(const onnxruntime::logging::Logger& logger);
void RegisterProviderLib(const std::string& provider_type, const std::string& lib_path);
void RegisterProviderLib(const std::string& provider_type,
const std::string& lib_path,
const std::string& entry_point);
onnxruntime::Status set_device(size_t device_index, const std::string& provider_type,
const onnxruntime::ProviderOptions& provider_options);
@ -32,7 +34,7 @@ private:
//custom op schema registry
//TODO: we might want to support load custom op schema on the fly
onnxruntime::IOnnxRuntimeOpSchemaRegistryList custom_op_schema_ = {};
std::unordered_map<std::string, std::string> additional_provider_libs_ = {};
std::unordered_map<std::string, std::pair<std::string, std::string> > additional_provider_libs_ = {};
};
ORTBackendsManager& GetORTBackendsManager();

View file

@ -51,9 +51,14 @@ void addObjectMethodsForEager(py::module& m){
return ORTTensor_FromDLPack(dlpack_tensor);
});
m.def("_register_provider_lib", [](const std::string& name, const std::string& provider_shared_lib_path ) {
torch_ort::eager::GetORTBackendsManager().RegisterProviderLib(name, provider_shared_lib_path);
});
m.def("_register_provider_lib", [](const std::string& name,
const std::string& provider_shared_lib_path,
const std::string& provider_factory_entry) {
torch_ort::eager::GetORTBackendsManager().RegisterProviderLib(name, provider_shared_lib_path, provider_factory_entry);
},
py::arg("name"),
py::arg("provider_shared_lib_path"),
py::arg("provider_factory_entry") = kDefaultExecutionProviderEntry);
m.def("set_device", [](size_t device_index,
const std::string& provider_type,

View file

@ -13,7 +13,7 @@ class OrtEPTests(unittest.TestCase):
def test_import_custom_eps(self):
torch_ort.set_device(0, 'CPUExecutionProvider', {})
torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path())
torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), 'ProviderEntryPoint')
torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
ort_device = torch_ort.device(1)