diff --git a/onnxruntime/core/framework/provider_bridge_ort.cc b/onnxruntime/core/framework/provider_bridge_ort.cc index 8dc938a055..a326845286 100644 --- a/onnxruntime/core/framework/provider_bridge_ort.cc +++ b/onnxruntime/core/framework/provider_bridge_ort.cc @@ -7,6 +7,7 @@ #include "core/framework/compute_capability.h" #include "core/framework/data_types.h" #include "core/framework/data_transfer_manager.h" +#include "core/framework/error_code_helper.h" #include "core/framework/execution_provider.h" #include "core/framework/kernel_registry.h" #include "core/framework/provider_bridge_ort.h" @@ -1096,6 +1097,7 @@ INcclService& INcclService::GetInstance() { } // namespace onnxruntime ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) { + API_IMPL_BEGIN auto factory = onnxruntime::CreateExecutionProviderFactory_Dnnl(use_arena); if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library"); @@ -1103,9 +1105,11 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessi options->provider_factories.push_back(factory); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id) { + API_IMPL_BEGIN auto factory = onnxruntime::CreateExecutionProviderFactory_Tensorrt(device_id); if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); @@ -1113,9 +1117,11 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS options->provider_factories.push_back(factory); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) { + API_IMPL_BEGIN auto factory = onnxruntime::CreateExecutionProviderFactory_Tensorrt(tensorrt_options); if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); @@ -1123,9 +1129,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In options->provider_factories.push_back(factory); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options) { + API_IMPL_BEGIN auto factory = onnxruntime::CreateExecutionProviderFactory_OpenVINO(provider_options); if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_OpenVINO: Failed to load shared library"); @@ -1133,10 +1141,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO, _In options->provider_factories.push_back(factory); return nullptr; + API_IMPL_END } ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, _In_ const char* device_type) { - OrtOpenVINOProviderOptions provider_options; + OrtOpenVINOProviderOptions provider_options{}; provider_options.device_type = device_type; return OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO(options, &provider_options); } @@ -1149,18 +1158,23 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessi } ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, _In_ int device_id) { + API_IMPL_BEGIN if (auto* info = onnxruntime::GetProviderInfo_CUDA()) return info->SetCurrentGpuDeviceId(device_id); return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled."); + API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, _In_ int* device_id) { + API_IMPL_BEGIN if (auto* info = onnxruntime::GetProviderInfo_CUDA()) return info->GetCurrentGpuDeviceId(device_id); return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled."); + API_IMPL_END } ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options) { + API_IMPL_BEGIN auto factory = onnxruntime::CreateExecutionProviderFactory_Cuda(cuda_options); if (!factory) { return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Cuda: Failed to load shared library"); @@ -1168,4 +1182,5 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_CUDA, _In_ Or options->provider_factories.push_back(factory); return nullptr; + API_IMPL_END }