mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-19 02:03:52 +00:00
support register external ep lib information (#8897)
* support register external ep lib inforation; make eager mode share the same ep pools with training workloads * fix inference code * fix build break * fix the message
This commit is contained in:
parent
3eb08d4dc7
commit
4dc0ddf606
15 changed files with 156 additions and 100 deletions
|
|
@ -22,7 +22,7 @@ namespace onnxruntime {
|
|||
|
||||
class ORTInvoker {
|
||||
public:
|
||||
ORTInvoker(std::unique_ptr<IExecutionProvider> execution_provider,
|
||||
ORTInvoker(std::shared_ptr<IExecutionProvider> execution_provider,
|
||||
const logging::Logger& logger,
|
||||
const IOnnxRuntimeOpSchemaRegistryList& custom_op_registries) :
|
||||
execution_provider_(std::move(execution_provider)), logger_(logger), custom_op_registries_(custom_op_registries) {
|
||||
|
|
@ -44,7 +44,7 @@ class ORTInvoker {
|
|||
const int version = -1);
|
||||
|
||||
private:
|
||||
std::unique_ptr<IExecutionProvider> execution_provider_;
|
||||
std::shared_ptr<IExecutionProvider> execution_provider_;
|
||||
const logging::Logger& logger_;
|
||||
// custom ops for current execution provider
|
||||
// we need the op schema to resolve the output type during invoke
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include "core/providers/get_execution_providers.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace python {
|
||||
|
|
@ -11,6 +13,12 @@ void CreateInferencePybindStateModule(py::module& m);
|
|||
|
||||
PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
|
||||
CreateInferencePybindStateModule(m);
|
||||
// move it out of shared method since training build has a little different behavior.
|
||||
m.def(
|
||||
"get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableExecutionProviderNames(); },
|
||||
"Return list of available Execution Providers in this installed version of Onnxruntime. "
|
||||
"The order of elements represents the default priority order of Execution Providers "
|
||||
"from highest to lowest.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -15,7 +15,6 @@
|
|||
#include "core/framework/arena_extend_strategy.h"
|
||||
#include "core/framework/data_transfer_utils.h"
|
||||
#include "core/framework/data_types_internal.h"
|
||||
#include "core/providers/get_execution_providers.h"
|
||||
#include "core/framework/provider_options_utils.h"
|
||||
#include "core/framework/random_seed.h"
|
||||
#include "core/framework/sparse_tensor.h"
|
||||
|
|
@ -23,6 +22,7 @@
|
|||
#include "core/framework/TensorSeq.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/providers/get_execution_providers.h"
|
||||
#include "core/session/IOBinding.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
|
|
@ -338,12 +338,12 @@ const ROCMExecutionProviderInfo GetROCMExecutionProviderInfo(const ProviderOptio
|
|||
#endif
|
||||
|
||||
std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
||||
InferenceSession* sess,
|
||||
const SessionOptions& session_options,
|
||||
const std::string& type,
|
||||
const ProviderOptionsMap& provider_options_map){
|
||||
if (type == kCpuExecutionProvider) {
|
||||
return onnxruntime::CreateExecutionProviderFactory_CPU(
|
||||
sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider();
|
||||
session_options.enable_cpu_mem_arena)->CreateProvider();
|
||||
} else if (type == kTensorrtExecutionProvider) {
|
||||
#ifdef USE_TENSORRT
|
||||
std::string calibration_table, cache_path, lib_path;
|
||||
|
|
@ -530,7 +530,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
} else if (type == kDnnlExecutionProvider) {
|
||||
#ifdef USE_DNNL
|
||||
return onnxruntime::CreateExecutionProviderFactory_Dnnl(
|
||||
sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider();
|
||||
session_options.enable_cpu_mem_arena)->CreateProvider();
|
||||
#endif
|
||||
} else if (type == kOpenVINOExecutionProvider) {
|
||||
#ifdef USE_OPENVINO
|
||||
|
|
@ -628,12 +628,12 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
} else if (type == kAclExecutionProvider) {
|
||||
#ifdef USE_ACL
|
||||
return onnxruntime::CreateExecutionProviderFactory_ACL(
|
||||
sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider();
|
||||
session_options.enable_cpu_mem_arena)->CreateProvider();
|
||||
#endif
|
||||
} else if (type == kArmNNExecutionProvider) {
|
||||
#ifdef USE_ARMNN
|
||||
return onnxruntime::CreateExecutionProviderFactory_ArmNN(
|
||||
sess->GetSessionOptions().enable_cpu_mem_arena)->CreateProvider();
|
||||
session_options.enable_cpu_mem_arena)->CreateProvider();
|
||||
#endif
|
||||
} else if (type == kDmlExecutionProvider) {
|
||||
#ifdef USE_DML
|
||||
|
|
@ -655,7 +655,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
|||
#if !defined(__ANDROID__)
|
||||
LOGS_DEFAULT(WARNING) << "NNAPI execution provider can only be used to generate ORT format model in this build.";
|
||||
#endif
|
||||
const auto partitioning_stop_ops_list = sess->GetSessionOptions().config_options.GetConfigEntry(
|
||||
const auto partitioning_stop_ops_list = session_options.config_options.GetConfigEntry(
|
||||
kOrtSessionOptionsConfigNnapiEpPartitioningStopOps);
|
||||
return onnxruntime::CreateExecutionProviderFactory_Nnapi(0, partitioning_stop_ops_list)->CreateProvider();
|
||||
#endif
|
||||
|
|
@ -705,7 +705,7 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
|
|||
ORT_UNUSED_PARAMETER(provider_options_map);
|
||||
|
||||
for (const std::string& type : provider_types) {
|
||||
auto ep = CreateExecutionProviderInstance(sess, type, provider_options_map);
|
||||
auto ep = CreateExecutionProviderInstance(sess->GetSessionOptions(), type, provider_options_map);
|
||||
if (ep)
|
||||
OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(ep)));
|
||||
}
|
||||
|
|
@ -833,11 +833,6 @@ void addGlobalMethods(py::module& m, Environment& env) {
|
|||
"Return list of Execution Providers that this version of Onnxruntime can support. "
|
||||
"The order of elements represents the default priority order of Execution Providers "
|
||||
"from highest to lowest.");
|
||||
m.def(
|
||||
"get_available_providers", []() -> const std::vector<std::string>& { return GetAvailableExecutionProviderNames(); },
|
||||
"Return list of available Execution Providers available in this installed version of Onnxruntime. "
|
||||
"The order of elements represents the default priority order of Execution Providers "
|
||||
"from highest to lowest.");
|
||||
m.def(
|
||||
"enable_telemetry_events", []() -> void { platform_env.GetTelemetryProvider().EnableTelemetryEvents(); },
|
||||
"Enables platform-specific telemetry collection where applicable.");
|
||||
|
|
|
|||
|
|
@ -1092,7 +1092,6 @@ class TestInferenceSession(unittest.TestCase):
|
|||
sess = C.InferenceSession(session_options, custom_op_model, True, True)
|
||||
sess.initialize_session(['my_ep'],
|
||||
[{'shared_lib_path': shared_library,
|
||||
'provider_factory_entry_point' : 'ProviderEntryPoint',
|
||||
'device_id':'1', 'some_config':'val'}],
|
||||
set())
|
||||
print("Create session with customize execution provider successfully!")
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
_ProviderEntryPoint
|
||||
_GetProvider
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ struct MyEP_Provider : Provider {
|
|||
|
||||
extern "C" {
|
||||
|
||||
ORT_API(onnxruntime::Provider*, ProviderEntryPoint) {
|
||||
ORT_API(onnxruntime::Provider*, GetProvider) {
|
||||
return &onnxruntime::g_provider;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ extern "C" {
|
|||
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MyEP, _In_ OrtSessionOptions* options, int device_id);
|
||||
|
||||
ORT_API(onnxruntime::Provider*, ProviderEntryPoint);
|
||||
ORT_API(onnxruntime::Provider*, GetProvider);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
EXPORTS
|
||||
ProviderEntryPoint
|
||||
GetProvider
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
#_init and _fini should be local
|
||||
VERS_1.0 {
|
||||
global:
|
||||
ProviderEntryPoint;
|
||||
GetProvider;
|
||||
|
||||
# Hide everything else.
|
||||
local:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@
|
|||
namespace onnxruntime{
|
||||
namespace python{
|
||||
Environment& GetTrainingORTEnv();
|
||||
std::shared_ptr<IExecutionProvider> GetOrCreateExecutionProvider(const std::string& provider_type,
|
||||
const ProviderOptionsMap& provider_options_map,
|
||||
const SessionOptions& session_options);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -41,53 +44,15 @@ ORTBackendsManager::ORTBackendsManager(const onnxruntime::logging::Logger& logge
|
|||
}
|
||||
}
|
||||
|
||||
void ORTBackendsManager::RegisterProviderLib(const std::string& provider_type,
|
||||
const std::string& lib_path,
|
||||
const std::string& entry_point){
|
||||
additional_provider_libs_.insert({provider_type, {lib_path, entry_point}});
|
||||
}
|
||||
|
||||
onnxruntime::Status ORTBackendsManager::set_device(size_t device_index, const std::string& provider_type,
|
||||
const ProviderOptions& provider_options){
|
||||
// query avalible device
|
||||
auto& available_providers = GetAvailableExecutionProviderNames();
|
||||
std::unique_ptr<IExecutionProvider> provider_p;
|
||||
if (std::find(available_providers.begin(), available_providers.end(), provider_type) != available_providers.end()){
|
||||
if (provider_type == kCpuExecutionProvider){
|
||||
provider_p = onnxruntime::CreateExecutionProviderFactory_CPU(0)->CreateProvider();
|
||||
}
|
||||
}
|
||||
else{
|
||||
auto shared_lib_path_it = additional_provider_libs_.find(provider_type);
|
||||
if (shared_lib_path_it == additional_provider_libs_.end()){
|
||||
return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME,
|
||||
common::StatusCode::INVALID_ARGUMENT,
|
||||
"Execution provider: " + provider_type + " is not supported.");
|
||||
}
|
||||
|
||||
void* handle;
|
||||
auto lib_path = shared_lib_path_it->second.first;
|
||||
auto entry_point = shared_lib_path_it->second.second;
|
||||
auto error = Env::Default().LoadDynamicLibrary(lib_path, false, &handle);
|
||||
if (!error.IsOK()) {
|
||||
return onnxruntime::Status(common::StatusCategory::ONNXRUNTIME,
|
||||
common::StatusCode::INVALID_ARGUMENT,
|
||||
"Load shared execution provider: " + provider_type + " failed: "
|
||||
+ error.ErrorMessage());
|
||||
}
|
||||
|
||||
Provider* (*PGetProvider)();
|
||||
ORT_RETURN_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle, entry_point, (void**)&PGetProvider));
|
||||
|
||||
Provider* provider = PGetProvider();
|
||||
std::shared_ptr<IExecutionProviderFactory> ep_factory = provider->CreateExecutionProviderFactory(&provider_options);
|
||||
provider_p = ep_factory->CreateProvider();
|
||||
}
|
||||
|
||||
auto ep = onnxruntime::python::GetOrCreateExecutionProvider(provider_type,
|
||||
ProviderOptionsMap{{provider_type, provider_options}},
|
||||
SessionOptions{});
|
||||
|
||||
auto invoker =
|
||||
std::make_unique<onnxruntime::ORTInvoker>(
|
||||
std::move(provider_p),
|
||||
std::move(ep),
|
||||
logger_,
|
||||
custom_op_schema_);
|
||||
|
||||
|
|
|
|||
|
|
@ -19,10 +19,6 @@ class ORTBackendsManager {
|
|||
public:
|
||||
ORTBackendsManager(const onnxruntime::logging::Logger& logger);
|
||||
|
||||
void RegisterProviderLib(const std::string& provider_type,
|
||||
const std::string& lib_path,
|
||||
const std::string& entry_point);
|
||||
|
||||
onnxruntime::Status set_device(size_t device_index, const std::string& provider_type,
|
||||
const onnxruntime::ProviderOptions& provider_options);
|
||||
|
||||
|
|
@ -34,7 +30,6 @@ private:
|
|||
//custom op schema registry
|
||||
//TODO: we might want to support load custom op schema on the fly
|
||||
onnxruntime::IOnnxRuntimeOpSchemaRegistryList custom_op_schema_ = {};
|
||||
std::unordered_map<std::string, std::pair<std::string, std::string> > additional_provider_libs_ = {};
|
||||
};
|
||||
|
||||
ORTBackendsManager& GetORTBackendsManager();
|
||||
|
|
|
|||
|
|
@ -52,15 +52,6 @@ void addObjectMethodsForEager(py::module& m){
|
|||
return ORTTensor_FromDLPack(dlpack_tensor);
|
||||
});
|
||||
|
||||
m.def("_register_provider_lib", [](const std::string& name,
|
||||
const std::string& provider_shared_lib_path,
|
||||
const std::string& provider_factory_entry) {
|
||||
torch_ort::eager::GetORTBackendsManager().RegisterProviderLib(name, provider_shared_lib_path, provider_factory_entry);
|
||||
},
|
||||
py::arg("name"),
|
||||
py::arg("provider_shared_lib_path"),
|
||||
py::arg("provider_factory_entry") = kDefaultExecutionProviderEntry);
|
||||
|
||||
m.def("set_device", [](size_t device_index,
|
||||
const std::string& provider_type,
|
||||
const std::unordered_map<std::string, std::string>& arguments){
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class OrtEPTests(unittest.TestCase):
|
|||
def test_import_custom_eps(self):
|
||||
torch_ort.set_device(0, 'CPUExecutionProvider', {})
|
||||
|
||||
torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), 'ProviderEntryPoint')
|
||||
torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), {})
|
||||
torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'})
|
||||
ort_device = torch_ort.device(1)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/logging/severity.h"
|
||||
#include "core/providers/get_execution_providers.h"
|
||||
|
||||
#include "core/platform/env.h"
|
||||
#include "core/session/provider_bridge_ort.h"
|
||||
|
|
@ -20,7 +21,7 @@ namespace py = pybind11;
|
|||
using namespace onnxruntime::logging;
|
||||
|
||||
std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
|
||||
InferenceSession* sess,
|
||||
const SessionOptions& session_options,
|
||||
const std::string& type,
|
||||
const ProviderOptionsMap& provider_options_map);
|
||||
|
||||
|
|
@ -41,6 +42,7 @@ void addObjectMethodsForEager(py::module& m);
|
|||
void InitArray();
|
||||
|
||||
using ExecutionProviderMap = std::unordered_map<std::string, std::shared_ptr<IExecutionProvider> >;
|
||||
using ExecutionProviderLibInfoMap = std::unordered_map<std::string, std::pair<std::string, ProviderOptions> > ;
|
||||
|
||||
bool GetDyanmicExecutionProviderHash(
|
||||
const std::string& ep_shared_lib_path,
|
||||
|
|
@ -137,6 +139,8 @@ public:
|
|||
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
|
||||
&SessionObjectInitializer::default_logger_id),
|
||||
ort_env_));
|
||||
auto& builtinEPs = GetAvailableExecutionProviderNames();
|
||||
available_training_eps_.assign(builtinEPs.begin(), builtinEPs.end());
|
||||
}
|
||||
|
||||
Environment& GetORTEnv(){
|
||||
|
|
@ -156,6 +160,20 @@ public:
|
|||
std::move(execution_provider)});
|
||||
}
|
||||
|
||||
void RegisterExtExecutionProviderInfo(const std::string& provider_type,
|
||||
const std::string& provider_lib_path,
|
||||
const ProviderOptions& default_options){
|
||||
ext_execution_provider_info_map_.insert({provider_type, {provider_lib_path, default_options}});
|
||||
if (std::find(available_training_eps_.begin(), available_training_eps_.end(), provider_type) == available_training_eps_.end())
|
||||
available_training_eps_.push_back(provider_type);
|
||||
}
|
||||
|
||||
const std::vector<std::string>& GetAvailableTrainingExecutionProviderTypes(){
|
||||
return available_training_eps_;
|
||||
}
|
||||
|
||||
ExecutionProviderLibInfoMap ext_execution_provider_info_map_;
|
||||
|
||||
private:
|
||||
std::string GetExecutionProviderMapKey(const std::string& provider_type,
|
||||
size_t hash){
|
||||
|
|
@ -166,6 +184,7 @@ private:
|
|||
|
||||
std::unique_ptr<Environment> ort_env_;
|
||||
ExecutionProviderMap execution_provider_instances_map_;
|
||||
std::vector<std::string> available_training_eps_;
|
||||
};
|
||||
|
||||
static std::unique_ptr<ORTTrainingPythonEnv> ort_training_env;
|
||||
|
|
@ -199,30 +218,72 @@ Environment& GetTrainingORTEnv() {
|
|||
return ort_training_env->GetORTEnv();
|
||||
}
|
||||
|
||||
void ResolveExtraProviderOptions(const std::string& provider_type,
|
||||
const ProviderOptionsMap& original_provider_options_map,
|
||||
ProviderOptionsMap& merged_options){
|
||||
auto& training_env = GetTrainingEnv();
|
||||
auto it = training_env.ext_execution_provider_info_map_.find(provider_type);
|
||||
if (it == training_env.ext_execution_provider_info_map_.end()){
|
||||
//nothing changed.
|
||||
merged_options = original_provider_options_map;
|
||||
}else{
|
||||
ProviderOptions options = it->second.second;
|
||||
options.insert({kExecutionProviderSharedLibraryPath, it->second.first});
|
||||
auto original_map_it = original_provider_options_map.find(provider_type);
|
||||
if (original_map_it != original_provider_options_map.end()){
|
||||
for (auto [k, v] : original_map_it->second){
|
||||
options.insert({k, v});
|
||||
}
|
||||
}
|
||||
merged_options[provider_type] = options;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> CreateTrainingEP(
|
||||
const SessionOptions& session_options,
|
||||
const std::string& provider_type,
|
||||
const ProviderOptionsMap& provider_options_map){
|
||||
ORTTrainingPythonEnv& training_env = GetTrainingEnv();
|
||||
if (training_env.ext_execution_provider_info_map_.find(provider_type) !=
|
||||
training_env.ext_execution_provider_info_map_.end()){
|
||||
ProviderOptionsMap merged_options;
|
||||
ResolveExtraProviderOptions(provider_type, provider_options_map, merged_options);
|
||||
return CreateExecutionProviderInstance(session_options, provider_type, merged_options);
|
||||
}else{
|
||||
return CreateExecutionProviderInstance(session_options, provider_type, provider_options_map);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProvider> GetOrCreateExecutionProvider(const std::string& provider_type,
|
||||
const ProviderOptionsMap& provider_options_map,
|
||||
const SessionOptions& session_options){
|
||||
ORTTrainingPythonEnv& training_env = GetTrainingEnv();
|
||||
// search in environment
|
||||
size_t hash;
|
||||
if (GetProviderInstanceHash(provider_type, provider_options_map, hash)){
|
||||
auto cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
|
||||
if (!cached_provider_instance){
|
||||
auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map);
|
||||
if (ep){
|
||||
training_env.AddExecutionProvider(provider_type, hash, std::move(ep));
|
||||
cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
|
||||
}
|
||||
}
|
||||
return cached_provider_instance;
|
||||
}
|
||||
else{
|
||||
// the EP doesn't support cache, register the instance to session
|
||||
auto ep = CreateTrainingEP(session_options, provider_type, provider_options_map);
|
||||
return ep;
|
||||
}
|
||||
}
|
||||
|
||||
void ORTTrainingRegisterExecutionProviders(InferenceSession* sess, const std::vector<std::string>& provider_types,
|
||||
const ProviderOptionsMap& provider_options_map) {
|
||||
// search in environment
|
||||
ORTTrainingPythonEnv& training_env = GetTrainingEnv();
|
||||
for (auto provider_type : provider_types){
|
||||
size_t hash;
|
||||
if (GetProviderInstanceHash(provider_type, provider_options_map, hash)){
|
||||
auto cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
|
||||
if (!cached_provider_instance){
|
||||
auto ep = CreateExecutionProviderInstance(sess, provider_type, provider_options_map);
|
||||
if (ep){
|
||||
training_env.AddExecutionProvider(provider_type, hash, std::move(ep));
|
||||
cached_provider_instance = training_env.GetExecutionProviderInstance(provider_type, hash);
|
||||
}
|
||||
}
|
||||
if (cached_provider_instance)
|
||||
OrtPybindThrowIfError(sess->RegisterExecutionProvider(cached_provider_instance));
|
||||
}
|
||||
else{
|
||||
// the EP doesn't support cache, register the instance to session
|
||||
auto ep = CreateExecutionProviderInstance(sess, provider_type, provider_options_map);
|
||||
if (ep)
|
||||
OrtPybindThrowIfError(sess->RegisterExecutionProvider(std::move(ep)));
|
||||
}
|
||||
auto ep = GetOrCreateExecutionProvider(provider_type, provider_options_map, sess->GetSessionOptions());
|
||||
if (ep)
|
||||
OrtPybindThrowIfError(sess->RegisterExecutionProvider(ep));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -250,6 +311,19 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
|
|||
#ifdef ENABLE_EAGER_MODE
|
||||
addObjectMethodsForEager(m);
|
||||
#endif
|
||||
|
||||
m.def("_register_provider_lib", [](const std::string& name,
|
||||
const std::string& provider_shared_lib_path,
|
||||
const ProviderOptions& default_options) {
|
||||
GetTrainingEnv().RegisterExtExecutionProviderInfo(name, provider_shared_lib_path, default_options);
|
||||
});
|
||||
|
||||
m.def(
|
||||
"get_available_providers", []() -> const std::vector<std::string>& {
|
||||
return GetTrainingEnv().GetAvailableTrainingExecutionProviderTypes(); },
|
||||
"Return list of available Execution Providers in this installed version of Onnxruntime. "
|
||||
"The order of elements represents the default priority order of Execution Providers "
|
||||
"from highest to lowest.");
|
||||
|
||||
// clean the ort training environment when python interpreter exit
|
||||
// otherwise the global var will be de-constrcut after user main.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,29 @@
|
|||
import unittest
|
||||
import onnxruntime_pybind11_state as C
|
||||
import os
|
||||
|
||||
class EPRegistrationTests(unittest.TestCase):
|
||||
def get_test_execution_provider_path(self):
|
||||
return os.path.join('.', 'libtest_execution_provider.so')
|
||||
|
||||
def test_register_custom_eps(self):
|
||||
C._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), {'some_config':'val'})
|
||||
|
||||
assert 'TestExecutionProvider' in C.get_available_providers()
|
||||
|
||||
this = os.path.dirname(__file__)
|
||||
custom_op_model = os.path.join(this, "testdata", "custom_execution_provider_library", "test_model.onnx")
|
||||
if not os.path.exists(custom_op_model):
|
||||
raise FileNotFoundError("Unable to find '{0}'".format(custom_op_model))
|
||||
|
||||
session_options = C.get_default_session_options()
|
||||
sess = C.InferenceSession(session_options, custom_op_model, True, True)
|
||||
sess.initialize_session(['TestExecutionProvider'],
|
||||
[{'device_id':'0'}],
|
||||
set())
|
||||
print("Created session with customize execution provider successfully!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in a new issue