mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Added GetAvailableProviders() to C API (#4247)
* Added GetAvailableProviders to C API * Fix API version and Windows build error * Changed function name * Changed ORT_API_VERSION to 4 * Moved all_providers array to constants.h * Move check for providers to constants.h * Changed name of array to avoid warning * Address review comment * Added unit test
This commit is contained in:
parent
175983c082
commit
57fabfba7a
5 changed files with 129 additions and 7 deletions
|
|
@ -38,4 +38,46 @@ constexpr const char* kDmlExecutionProvider = "DmlExecutionProvider";
|
|||
constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";
|
||||
constexpr const char* kAclExecutionProvider = "ACLExecutionProvider";
|
||||
constexpr const char* kArmNNExecutionProvider = "ArmNNExecutionProvider";
|
||||
constexpr const char *providers_available[] = {
|
||||
kCpuExecutionProvider,
|
||||
#ifdef USE_CUDA
|
||||
kCudaExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_DNNL
|
||||
kDnnlExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_NGRAPH
|
||||
kNGraphExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_OPENVINO
|
||||
kOpenVINOExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_NUPHAR
|
||||
kNupharExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_VITISAI
|
||||
kVitisAIExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_TENSORRT
|
||||
kTensorrtExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_NNAPI
|
||||
kNnapiExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_RKNPU
|
||||
kRknpuExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_DML
|
||||
kDmlExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_MIGRAPHX
|
||||
kMIGraphXExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_ACL
|
||||
kAclExecutionProvider,
|
||||
#endif
|
||||
#ifdef USE_ARMNN
|
||||
kArmNNExecutionProvider,
|
||||
#endif
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
#include <string.h>
|
||||
|
||||
// This value is used in structures passed to ORT so that a newer version of ORT will still work with them
|
||||
#define ORT_API_VERSION 3
|
||||
#define ORT_API_VERSION 4
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
|
|
@ -817,6 +817,25 @@ struct OrtApi {
|
|||
ORT_API2_STATUS(AddFreeDimensionOverrideByName,
|
||||
_Inout_ OrtSessionOptions* options, _In_ const char* dim_name,
|
||||
_In_ int64_t dim_value);
|
||||
|
||||
/**
|
||||
* \param out_ptr will hold a pointer to the array of char *
|
||||
* representing available providers.
|
||||
* \param provider_length is a pointer to an int variable where
|
||||
* the number of available providers will be added.
|
||||
* The caller is responsible for freeing each char * and the pointer
|
||||
* array by calling ReleaseAvailableProviders().
|
||||
*/
|
||||
ORT_API2_STATUS(GetAvailableProviders, _Outptr_ char ***out_ptr,
|
||||
_In_ int *provider_length);
|
||||
|
||||
/**
|
||||
* \param ptr is the pointer to an array of available providers you
|
||||
* get after calling GetAvailableProviders().
|
||||
* \param providers_length is the number of available providers.
|
||||
*/
|
||||
ORT_API2_STATUS(ReleaseAvailableProviders, _In_ char **ptr,
|
||||
_In_ int providers_length);
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/status.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "core/graph/graph.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/tensor.h"
|
||||
|
|
@ -981,7 +982,7 @@ ORT_STATUS_PTR OrtGetValueImplSeqOfTensors(_In_ const OrtValue* p_ml_value, int
|
|||
return t_disp.template InvokeWithUnsupportedPolicy<UnsupportedReturnFailStatus>(allocator, one_tensor, out);
|
||||
}
|
||||
|
||||
#ifdef _MSVC_VER
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
|
||||
|
|
@ -1375,6 +1376,46 @@ ORT_API_STATUS_IMPL(OrtApis::GetOpaqueValue, _In_ const char* domain_name, _In_
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char ***out_ptr,
|
||||
_In_ int *providers_length) {
|
||||
API_IMPL_BEGIN
|
||||
const size_t MAX_LEN = 30;
|
||||
int available_count = (int)(sizeof(providers_available) / sizeof(char *));
|
||||
char **out = (char **)malloc(available_count * sizeof(char *));
|
||||
if(out) {
|
||||
for(int i = 0; i < available_count; i++) {
|
||||
out[i] = (char *)malloc((MAX_LEN + 1) * sizeof(char));
|
||||
if(out[i]) {
|
||||
#ifdef _MSC_VER
|
||||
strncpy_s(out[i], MAX_LEN, providers_available[i], MAX_LEN);
|
||||
#else
|
||||
strncpy(out[i], providers_available[i], MAX_LEN);
|
||||
#endif
|
||||
out[i][MAX_LEN] = '\0';
|
||||
}
|
||||
}
|
||||
}
|
||||
*providers_length = available_count;
|
||||
*out_ptr = out;
|
||||
API_IMPL_END
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char **ptr,
|
||||
_In_ int providers_length) {
|
||||
API_IMPL_BEGIN
|
||||
if(ptr) {
|
||||
for(int i = 0; i < providers_length; i++) {
|
||||
if(ptr[i]) {
|
||||
free(ptr[i]);
|
||||
}
|
||||
}
|
||||
free(ptr);
|
||||
}
|
||||
API_IMPL_END
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// End support for non-tensor types
|
||||
|
||||
static constexpr OrtApiBase ort_api_base = {
|
||||
|
|
@ -1421,7 +1462,7 @@ Second example, if we wanted to add and remove some members, we'd do this:
|
|||
In GetApi we now make it return ort_api_3 for version 3.
|
||||
*/
|
||||
|
||||
static constexpr OrtApi ort_api_1_to_3 = {
|
||||
static constexpr OrtApi ort_api_1_to_4 = {
|
||||
// NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering.
|
||||
|
||||
// Shipped as version 1 - DO NOT MODIFY (see above text for more information)
|
||||
|
|
@ -1560,21 +1601,26 @@ static constexpr OrtApi ort_api_1_to_3 = {
|
|||
&OrtApis::ReleaseModelMetadata,
|
||||
// End of Version 2 - DO NOT MODIFY ABOVE (see above text for more information)
|
||||
|
||||
// Version 3 - In development, feel free to add/remove/rearrange here
|
||||
&OrtApis::CreateEnvWithGlobalThreadPools,
|
||||
&OrtApis::DisablePerSessionThreads,
|
||||
&OrtApis::CreateThreadingOptions,
|
||||
&OrtApis::ReleaseThreadingOptions,
|
||||
&OrtApis::ModelMetadataGetCustomMetadataMapKeys,
|
||||
&OrtApis::AddFreeDimensionOverrideByName};
|
||||
&OrtApis::AddFreeDimensionOverrideByName,
|
||||
// End of Version 3 - DO NOT MODIFY ABOVE (see above text for more information)
|
||||
|
||||
// Version 4 - In development, feel free to add/remove/rearrange here
|
||||
&OrtApis::GetAvailableProviders,
|
||||
&OrtApis::ReleaseAvailableProviders,
|
||||
};
|
||||
|
||||
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
|
||||
// If this assert hits, read the above 'Rules on how to add a new Ort API version'
|
||||
static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 101, "Size of version 1 API cannot change");
|
||||
|
||||
ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
|
||||
if (version >= 1 && version <= 3)
|
||||
return &ort_api_1_to_3;
|
||||
if (version >= 1 && version <= 4)
|
||||
return &ort_api_1_to_4;
|
||||
|
||||
return nullptr; // Unsupported version
|
||||
}
|
||||
|
|
|
|||
|
|
@ -198,4 +198,8 @@ ORT_API_STATUS_IMPL(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMe
|
|||
|
||||
ORT_API_STATUS_IMPL(AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, _In_ int64_t dim_value);
|
||||
|
||||
ORT_API_STATUS_IMPL(GetAvailableProviders, _Outptr_ char ***out_ptr,
|
||||
_In_ int *providers_length);
|
||||
ORT_API_STATUS_IMPL(ReleaseAvailableProviders, _In_ char **ptr,
|
||||
_In_ int providers_length);
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include <core/common/make_unique.h>
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "core/session/onnxruntime_cxx_api.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "providers.h"
|
||||
|
|
@ -581,3 +582,13 @@ TEST(CApiTest, model_metadata) {
|
|||
ASSERT_TRUE(custom_metadata_map_keys == nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CApiTest, get_available_providers) {
|
||||
const OrtApi *g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
|
||||
int len = 0;
|
||||
char **providers;
|
||||
ASSERT_EQ(g_ort->GetAvailableProviders(&providers, &len), nullptr);
|
||||
ASSERT_TRUE(len > 0);
|
||||
ASSERT_EQ(strcmp(providers[0], "CPUExecutionProvider"), 0);
|
||||
ASSERT_EQ(g_ort->ReleaseAvailableProviders(providers, len), nullptr);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue