support register external ep lib information (#8897)

* support register external ep lib inforation; make eager mode share the same ep pools with training workloads

* fix inference code

* fix build break

* fix the message
This commit is contained in:
Tang, Cheng 2021-08-31 20:51:22 -07:00 committed by GitHub
parent 3eb08d4dc7
commit 4dc0ddf606
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 156 additions and 100 deletions

View file

@ -22,7 +22,7 @@ namespace onnxruntime {
class ORTInvoker {
public:
ORTInvoker(std::unique_ptr<IExecutionProvider> execution_provider,
ORTInvoker(std::shared_ptr<IExecutionProvider> execution_provider,
const logging::Logger& logger,
const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) :
execution_provider_(std::move(execution_provider)), logger_(logger), custom_op_registries_(custom_op_registries) {
@ -44,7 +44,7 @@ class ORTInvoker {
const int version = -1);
private:
std::unique_ptr<IExecutionProvider> execution_provider_;
std::shared_ptr<IExecutionProvider> execution_provider_;
const logging::Logger& logger_;
// custom ops for current execution provider
// we need the op schema to resolve the output type during invoke

View file

@ -2,6 +2,8 @@
// Licensed under the MIT License.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "core/providers/get_execution_providers.h"
namespace onnxruntime {
namespace python {
@ -11,6 +13,12 @@ void CreateInferencePybindStateModule(py::module& m);
PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
CreateInferencePybindStateModule(m);
// move it out of shared method since training build has a little different behavior.
m.def(
"get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableExecutionProviderNames(); },
"Return list of available Execution Providers in this installed version of Onnxruntime. "
"The order of elements represents the default priority order of Execution Providers "
"from highest to lowest.");
}
}
}

View file

@ -15,7 +15,6 @@
#include "core/framework/arena_extend_strategy.h"
#include "core/framework/data_transfer_utils.h"
#include "core/framework/data_types_internal.h"
#include "core/providers/get_execution_providers.h"
#include "core/framework/provider_options_utils.h"
#include "core/framework/random_seed.h"
#include "core/framework/sparse_tensor.h"
@ -23,6 +22,7 @@
#include "core/framework/TensorSeq.h"
#include "core/graph/graph_viewer.h"
#include "core/platform/env.h"
#include "core/providers/get_execution_providers.h"
#include "core/session/IOBinding.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
@ -338,12 +338,12 @@ const ROCMExecutionProviderInfo GetROCMExecutionProviderInfo(const ProviderOptio
#endif
std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
InferenceSession* sess,
const SessionOptions& session_options,
const std::string& type,
const ProviderOptionsMap& provider_options_map){
if (type == kCpuExecutionProvider) {
return onnxruntime::CreateExecutionProviderFactory_CPU(
sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider();
session_options.enable_cpu_mem_arena)->CreateProvider();
} else if (type == kTensorrtExecutionProvider) {
#ifdef USE_TENSORRT
std::string calibration_table, cache_path, lib_path;
@ -530,7 +530,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else if (type == kDnnlExecutionProvider) {
#ifdef USE_DNNL
return onnxruntime::CreateExecutionProviderFactory_Dnnl(
sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider();
session_options.enable_cpu_mem_arena)->CreateProvider();
#endif
} else if (type == kOpenVINOExecutionProvider) {
#ifdef USE_OPENVINO
@ -628,12 +628,12 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else if (type == kAclExecutionProvider) {
#ifdef USE_ACL
return onnxruntime::CreateExecutionProviderFactory_ACL(
sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider();
session_options.enable_cpu_mem_arena)->CreateProvider();
#endif
} else if (type == kArmNNExecutionProvider) {
#ifdef USE_ARMNN
return onnxruntime::CreateExecutionProviderFactory_ArmNN(
sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider();
session_options.enable_cpu_mem_arena)->CreateProvider();
#endif
} else if (type == kDmlExecutionProvider) {
#ifdef USE_DML
@ -655,7 +655,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
#if !defined(__ANDROID__)
LOGS_DEFAULT(WARNING) << "NNAPI execution provider can only be used to generate ORT format model in this build.";
#endif
const auto partitioning_stop_ops_list = sess->GetSessionOptions().config_options.GetConfigEntry(
const auto partitioning_stop_ops_list = session_options.config_options.GetConfigEntry(
kOrtSessionOptionsConfigNnapiEpPartitioningStopOps);
return onnxruntime::CreateExecutionProviderFactory_Nnapi(0, partitioning_stop_ops_list)->CreateProvider();
#endif
@ -705,7 +705,7 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
ORT_UNUSED_PARAMETER(provider_options_map);
for (const std::string& type : provider_types) {
auto ep = CreateExecutionProviderInstance(sess, type, provider_options_map);
auto ep = CreateExecutionProviderInstance(sess->GetSessionOptions(), type, provider_options_map);
if (ep)
OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(ep)));
}
@ -833,11 +833,6 @@ void addGlobalMethods(py::module& m, Environment& env) {
"Return list of Execution Providers that this version of Onnxruntime can support. "
"The order of elements represents the default priority order of Execution Providers "
"from highest to lowest.");
m.def(
"get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableExecutionProviderNames(); },
"Return list of available Execution Providers available in this installed version of Onnxruntime. "
"The order of elements represents the default priority order of Execution Providers "
"from highest to lowest.");
m.def(
"enable_telemetry_events", []() -> void { platform_env.GetTelemetryProvider().EnableTelemetryEvents(); },
"Enables platform-specific telemetry collection where applicable.");

View file

@ -1092,7 +1092,6 @@ class TestInferenceSession(unittest.TestCase):
sess = C.InferenceSession(session_options, custom_op_model, True, True)
sess.initialize_session(['my_ep'],
[{'shared_lib_path': shared_library,
'provider_factory_entry_point' : 'ProviderEntryPoint',
'device_id':'1', 'some_config':'val'}],
set())
print("Create session with customize execution provider successfully!")

View file

@ -1 +1 @@
_ProviderEntryPoint
_GetProvider

View file

@ -64,7 +64,7 @@ struct MyEP_Provider : Provider {
extern "C" {
ORT_API(onnxruntime::Provider*, ProviderEntryPoint) {
ORT_API(onnxruntime::Provider*, GetProvider) {
return &onnxruntime::g_provider;
}

View file

@ -10,7 +10,7 @@ extern "C" {
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MyEP, _In_ OrtSessionOptions* options, int device_id);
ORT_API(onnxruntime::Provider*, ProviderEntryPoint);
ORT_API(onnxruntime::Provider*, GetProvider);
#ifdef __cplusplus
}

View file

@ -1,2 +1,2 @@
EXPORTS
ProviderEntryPoint
GetProvider

View file

@ -1,7 +1,7 @@
#_init and _fini should be local
VERS_1.0 {
global:
ProviderEntryPoint;
GetProvider;
# Hide everything else.
local:

View file

@ -15,6 +15,9 @@
namespace onnxruntime{
namespace python{
Environment& GetTrainingORTEnv();
std::shared_ptr<IExecutionProvider> GetOrCreateExecutionProvider(const std::string& provider_type,
const ProviderOptionsMap& provider_options_map,
const SessionOptions& session_options);
}
}
@ -41,53 +44,15 @@ ORTBackendsManager::ORTBackendsManager(const onnxruntime::logging::Logger& logge
}
}
void ORTBackendsManager::RegisterProviderLib(const std::string& provider_type,
const std::string& lib_path,
const std::string& entry_point){
additional_provider_libs_.insert({provider_type, {lib_path, entry_point}});
}
onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const std::string& provider_type,
const ProviderOptions& provider_options){
// query avalible device
auto& available_providers = GetAvailableExecutionProviderNames();
std::unique_ptr<IExecutionProvider> provider_p;
if (std::find(available_providers.begin(), available_providers.end(), provider_type) != available_providers.end()){
if (provider_type == kCpuExecutionProvider){
provider_p = onnxruntime::CreateExecutionProviderFactory_CPU(0)->CreateProvider();
}
}
else{
auto shared_lib_path_it = additional_provider_libs_.find(provider_type);
if (shared_lib_path_it == additional_provider_libs_.end()){
return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME,
common::StatusCode::INVALID_ARGUMENT,
"Execution provider: " + provider_type + " is not supported.");
}
void* handle;
auto lib_path = shared_lib_path_it->second.first;
auto entry_point = shared_lib_path_it->second.second;
auto error = Env::Default().LoadDynamicLibrary(lib_path, false, &handle);
if (!error.IsOK()) {
return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME,
common::StatusCode::INVALID_ARGUMENT,
"Load shared execution provider: " + provider_type + " failed: "
+ error.ErrorMessage());
}
Provider* (*PGetProvider)();
ORT_RETURN_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle, entry_point, (void**)&PGetProvider));
Provider* provider = PGetProvider();
std::shared_ptr<IExecutionProviderFactory> ep_factory = provider->CreateExecutionProviderFactory(&provider_options);
provider_p = ep_factory->CreateProvider();
}
auto ep = onnxruntime::python::GetOrCreateExecutionProvider(provider_type,
ProviderOptionsMap{{provider_type, provider_options}},
SessionOptions{});
auto invoker =
std::make_unique<onnxruntime::ORTInvoker>(
std::move(provider_p),
std::move(ep),
logger_,
custom_op_schema_);

View file

@ -19,10 +19,6 @@ class ORTBackendsManager {
public:
ORTBackendsManager(const onnxruntime::logging::Logger& logger);
void RegisterProviderLib(const std::string& provider_type,
const std::string& lib_path,
const std::string& entry_point);
onnxruntime::Status set_device(size_t device_index, const std::string& provider_type,
const onnxruntime::ProviderOptions& provider_options);
@ -34,7 +30,6 @@ private:
//custom op schema registry
//TODO: we might want to support load custom op schema on the fly
onnxruntime::IOnnxRuntimeOpSchemaRegistryList custom_op_schema_ = {};
std::unordered_map<std::string, std::pair<std::string, std::string> > additional_provider_libs_ = {};
};
ORTBackendsManager& GetORTBackendsManager();

View file

@ -52,15 +52,6 @@ void addObjectMethodsForEager(py::module& m){
return ORTTensor_FromDLPack(dlpack_tensor);
});
m.def("_register_provider_lib", [](const std::string& name,
const std::string& provider_shared_lib_path,
const std::string& provider_factory_entry) {
torch_ort::eager::GetORTBackendsManager().RegisterProviderLib(name, provider_shared_lib_path, provider_factory_entry);
},
py::arg("name"),
py::arg("provider_shared_lib_path"),
py::arg("provider_factory_entry") = kDefaultExecutionProviderEntry);
m.def("set_device", [](size_t device_index,
const std::string& provider_type,
const std::unordered_map<std::string, std::string>& arguments){

View file

@ -13,7 +13,7 @@ class OrtEPTests(unittest.TestCase):
def test_import_custom_eps(self):
torch_ort.set_device(0, 'CPUExecutionProvider', {})
torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), 'ProviderEntryPoint')
torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), {})
torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
ort_device = torch_ort.device(1)

View file

@ -7,6 +7,7 @@
#include "core/common/logging/logging.h"
#include "core/common/logging/severity.h"
#include "core/providers/get_execution_providers.h"
#include "core/platform/env.h"
#include "core/session/provider_bridge_ort.h"
@ -20,7 +21,7 @@ namespace py = pybind11;
using namespace onnxruntime::logging;
std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
InferenceSession* sess,
const SessionOptions& session_options,
const std::string& type,
const ProviderOptionsMap& provider_options_map);
@ -41,6 +42,7 @@ void addObjectMethodsForEager(py::module& m);
void InitArray();
using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider> >;
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions> > ;
bool GetDyanmicExecutionProviderHash(
const std::string& ep_shared_lib_path,
@ -137,6 +139,8 @@ public:
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
&SessionObjectInitializer::default_logger_id),
ort_env_));
auto& builtinEPs = GetAvailableExecutionProviderNames();
available_training_eps_.assign(builtinEPs.begin(), builtinEPs.end());
}
Environment& GetORTEnv(){
@ -156,6 +160,20 @@ public:
std::move(execution_provider)});
}
void RegisterExtExecutionProviderInfo(const std::string& provider_type,
const std::string& provider_lib_path,
const ProviderOptions& default_options){
ext_execution_provider_info_map_.insert({provider_type, {provider_lib_path, default_options}});
if (std::find(available_training_eps_.begin(), available_training_eps_.end(), provider_type) == available_training_eps_.end())
available_training_eps_.push_back(provider_type);
}
const std::vector<std::string>& GetAvailableTrainingExecutionProviderTypes(){
return available_training_eps_;
}
ExecutionProviderLibInfoMap ext_execution_provider_info_map_;
private:
std::string GetExecutionProviderMapKey(const std::string& provider_type,
size_t hash){
@ -166,6 +184,7 @@ private:
std::unique_ptr<Environment> ort_env_;
ExecutionProviderMap execution_provider_instances_map_;
std::vector<std::string> available_training_eps_;
};
static std::unique_ptr<ORTTrainingPythonEnv> ort_training_env;
@ -199,30 +218,72 @@ Environment& GetTrainingORTEnv() {
return ort_training_env->GetORTEnv();
}
void ResolveExtraProviderOptions(const std::string& provider_type,
const ProviderOptionsMap& original_provider_options_map,
ProviderOptionsMap& merged_options){
auto& training_env = GetTrainingEnv();
auto it = training_env.ext_execution_provider_info_map_.find(provider_type);
if (it == training_env.ext_execution_provider_info_map_.end()){
//nothing changed.
merged_options = original_provider_options_map;
}else{
ProviderOptions options = it->second.second;
options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
auto original_map_it = original_provider_options_map.find(provider_type);
if (original_map_it != original_provider_options_map.end()){
for (auto [k, v] : original_map_it->second){
options.insert({k, v});
}
}
merged_options[provider_type] = options;
}
}
std::unique_ptr<IExecutionProvider> CreateTrainingEP(
const SessionOptions& session_options,
const std::string& provider_type,
const ProviderOptionsMap& provider_options_map){
ORTTrainingPythonEnv& training_env = GetTrainingEnv();
if (training_env.ext_execution_provider_info_map_.find(provider_type) !=
training_env.ext_execution_provider_info_map_.end()){
ProviderOptionsMap merged_options;
ResolveExtraProviderOptions(provider_type, provider_options_map, merged_options);
return CreateExecutionProviderInstance(session_options, provider_type, merged_options);
}else{
return CreateExecutionProviderInstance(session_options, provider_type, provider_options_map);
}
}
std::shared_ptr<IExecutionProvider> GetOrCreateExecutionProvider(const std::string& provider_type,
const ProviderOptionsMap& provider_options_map,
const SessionOptions& session_options){
ORTTrainingPythonEnv& training_env = GetTrainingEnv();
// search in environment
size_t hash;
if (GetProviderInstanceHash(provider_type, provider_options_map, hash)){
auto cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
if (!cached_provider_instance){
auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map);
if (ep){
training_env.AddExecutionProvider(provider_type, hash, std::move(ep));
cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
}
}
return cached_provider_instance;
}
else{
// the EP doesn't support cache, register the instance to session
auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map);
return ep;
}
}
void ORTTrainingRegisterExecutionProviders(InferenceSession* sess, const std::vector<std::string>& provider_types,
const ProviderOptionsMap& provider_options_map) {
// search in environment
ORTTrainingPythonEnv& training_env = GetTrainingEnv();
for (auto provider_type : provider_types){
size_t hash;
if (GetProviderInstanceHash(provider_type, provider_options_map, hash)){
auto cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
if (!cached_provider_instance){
auto ep = CreateExecutionProviderInstance(sess, provider_type, provider_options_map);
if (ep){
training_env.AddExecutionProvider(provider_type, hash, std::move(ep));
cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
}
}
if (cached_provider_instance)
OrtPybindThrowIfError(sess->RegisterExecutionProvider(cached_provider_instance));
}
else{
// the EP doesn't support cache, register the instance to session
auto ep = CreateExecutionProviderInstance(sess, provider_type, provider_options_map);
if (ep)
OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(ep)));
}
auto ep = GetOrCreateExecutionProvider(provider_type, provider_options_map, sess->GetSessionOptions());
if (ep)
OrtPybindThrowIfError(sess->RegisterExecutionProvider(ep));
}
}
@ -250,6 +311,19 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
#ifdef ENABLE_EAGER_MODE
addObjectMethodsForEager(m);
#endif
m.def("_register_provider_lib", [](const std::string& name,
const std::string& provider_shared_lib_path,
const ProviderOptions& default_options) {
GetTrainingEnv().RegisterExtExecutionProviderInfo(name, provider_shared_lib_path, default_options);
});
m.def(
"get_available_providers", []() -> const std::vector<std::string>& {
return GetTrainingEnv().GetAvailableTrainingExecutionProviderTypes(); },
"Return list of available Execution Providers in this installed version of Onnxruntime. "
"The order of elements represents the default priority order of Execution Providers "
"from highest to lowest.");
// clean the ort training environment when python interpreter exit
// otherwise the global var will be de-constrcut after user main.

View file

@ -0,0 +1,29 @@
import unittest
import onnxruntime_pybind11_state as C
import os
class EPRegistrationTests(unittest.TestCase):
def get_test_execution_provider_path(self):
return os.path.join('.', 'libtest_execution_provider.so')
def test_register_custom_eps(self):
C._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), {'some_config':'val'})
assert 'TestExecutionProvider' in C.get_available_providers()
this = os.path.dirname(__file__)
custom_op_model = os.path.join(this, "testdata", "custom_execution_provider_library", "test_model.onnx")
if not os.path.exists(custom_op_model):
raise FileNotFoundError("Unable to find '{0}'".format(custom_op_model))
session_options = C.get_default_session_options()
sess = C.InferenceSession(session_options, custom_op_model, True, True)
sess.initialize_session(['TestExecutionProvider'],
[{'device_id':'0'}],
set())
print("Created session with customize execution provider successfully!")
if __name__ == '__main__':
unittest.main()