Add API_IMPL_* blocks around shared provider methods as they are C APIs (#7908)

This commit is contained in:
Ryan Hill 2021-06-01 19:28:00 -07:00 committed by GitHub
parent f9587d6051
commit a9f7eef754
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -7,6 +7,7 @@
#include "core/framework/compute_capability.h"
#include "core/framework/data_types.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/execution_provider.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/provider_bridge_ort.h"
@ -1096,6 +1097,7 @@ INcclService& INcclService::GetInstance() {
} // namespace onnxruntime
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) {
API_IMPL_BEGIN
auto factory = onnxruntime::CreateExecutionProviderFactory_Dnnl(use_arena);
if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Dnnl: Failed to load shared library");
@ -1103,9 +1105,11 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessi
options->provider_factories.push_back(factory);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id) {
API_IMPL_BEGIN
auto factory = onnxruntime::CreateExecutionProviderFactory_Tensorrt(device_id);
if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library");
@ -1113,9 +1117,11 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS
options->provider_factories.push_back(factory);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) {
API_IMPL_BEGIN
auto factory = onnxruntime::CreateExecutionProviderFactory_Tensorrt(tensorrt_options);
if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library");
@ -1123,9 +1129,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In
options->provider_factories.push_back(factory);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options) {
API_IMPL_BEGIN
auto factory = onnxruntime::CreateExecutionProviderFactory_OpenVINO(provider_options);
if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_OpenVINO: Failed to load shared library");
@ -1133,10 +1141,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO, _In
options->provider_factories.push_back(factory);
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, _In_ const char* device_type) {
OrtOpenVINOProviderOptions provider_options;
OrtOpenVINOProviderOptions provider_options{};
provider_options.device_type = device_type;
return OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO(options, &provider_options);
}
@ -1149,18 +1158,23 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessi
}
ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, _In_ int device_id) {
API_IMPL_BEGIN
if (auto* info = onnxruntime::GetProviderInfo_CUDA())
return info->SetCurrentGpuDeviceId(device_id);
return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled.");
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, _In_ int* device_id) {
API_IMPL_BEGIN
if (auto* info = onnxruntime::GetProviderInfo_CUDA())
return info->GetCurrentGpuDeviceId(device_id);
return CreateStatus(ORT_FAIL, "CUDA execution provider is not enabled.");
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options) {
API_IMPL_BEGIN
auto factory = onnxruntime::CreateExecutionProviderFactory_Cuda(cuda_options);
if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Cuda: Failed to load shared library");
@ -1168,4 +1182,5 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_CUDA, _In_ Or
options->provider_factories.push_back(factory);
return nullptr;
API_IMPL_END
}