From 57fabfba7a97ee8b7a8df162c791be2ade32846b Mon Sep 17 00:00:00 2001 From: Prabhat Date: Mon, 22 Jun 2020 07:40:25 +0530 Subject: [PATCH] Added GetAvailableProviders() to C API (#4247) * Added GetAvailableProviders to C API * Fix API version and Windows build error * Changed function name * Changed ORT_API_VERSION to 4 * Moved all_providers array to constants.h * Move check for providers to constants.h * Changed name of array to avoid warning * Address review comment * Added unit test --- include/onnxruntime/core/graph/constants.h | 42 ++++++++++++++ .../core/session/onnxruntime_c_api.h | 21 ++++++- onnxruntime/core/session/onnxruntime_c_api.cc | 58 +++++++++++++++++-- onnxruntime/core/session/ort_apis.h | 4 ++ onnxruntime/test/shared_lib/test_inference.cc | 11 ++++ 5 files changed, 129 insertions(+), 7 deletions(-) diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 1cf3677e2e..d213874038 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -38,4 +38,46 @@ constexpr const char* kDmlExecutionProvider = "DmlExecutionProvider"; constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider"; constexpr const char* kAclExecutionProvider = "ACLExecutionProvider"; constexpr const char* kArmNNExecutionProvider = "ArmNNExecutionProvider"; +constexpr const char *providers_available[] = { + kCpuExecutionProvider, +#ifdef USE_CUDA + kCudaExecutionProvider, +#endif +#ifdef USE_DNNL + kDnnlExecutionProvider, +#endif +#ifdef USE_NGRAPH + kNGraphExecutionProvider, +#endif +#ifdef USE_OPENVINO + kOpenVINOExecutionProvider, +#endif +#ifdef USE_NUPHAR + kNupharExecutionProvider, +#endif +#ifdef USE_VITISAI + kVitisAIExecutionProvider, +#endif +#ifdef USE_TENSORRT + kTensorrtExecutionProvider, +#endif +#ifdef USE_NNAPI + kNnapiExecutionProvider, +#endif +#ifdef USE_RKNPU + kRknpuExecutionProvider, +#endif +#ifdef USE_DML + kDmlExecutionProvider, +#endif +#ifdef USE_MIGRAPHX + kMIGraphXExecutionProvider, +#endif +#ifdef USE_ACL + kAclExecutionProvider, +#endif +#ifdef USE_ARMNN + kArmNNExecutionProvider, +#endif +}; } // namespace onnxruntime diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b5dd599707..c94631620a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7,7 +7,7 @@ #include // This value is used in structures passed to ORT so that a newer version of ORT will still work with them -#define ORT_API_VERSION 3 +#define ORT_API_VERSION 4 #ifdef __cplusplus extern "C" { @@ -817,6 +817,25 @@ struct OrtApi { ORT_API2_STATUS(AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, _In_ int64_t dim_value); + + /** + * \param out_ptr will hold a pointer to the array of char * + * representing available providers. + * \param provider_length is a pointer to an int variable where + * the number of available providers will be added. + * The caller is responsible for freeing each char * and the pointer + * array by calling ReleaseAvailableProviders(). + */ + ORT_API2_STATUS(GetAvailableProviders, _Outptr_ char ***out_ptr, + _In_ int *provider_length); + + /** + * \param ptr is the pointer to an array of available providers you + * get after calling GetAvailableProviders(). + * \param providers_length is the number of available providers. + */ + ORT_API2_STATUS(ReleaseAvailableProviders, _In_ char **ptr, + _In_ int providers_length); }; /* diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 4851cfa709..f7db1d9c58 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -14,6 +14,7 @@ #include "core/common/logging/logging.h" #include "core/common/status.h" #include "core/common/safeint.h" +#include "core/graph/constants.h" #include "core/graph/graph.h" #include "core/framework/allocator.h" #include "core/framework/tensor.h" @@ -981,7 +982,7 @@ ORT_STATUS_PTR OrtGetValueImplSeqOfTensors(_In_ const OrtValue* p_ml_value, int return t_disp.template InvokeWithUnsupportedPolicy(allocator, one_tensor, out); } -#ifdef _MSVC_VER +#ifdef _MSC_VER #pragma warning(pop) #endif @@ -1375,6 +1376,46 @@ ORT_API_STATUS_IMPL(OrtApis::GetOpaqueValue, _In_ const char* domain_name, _In_ return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char ***out_ptr, + _In_ int *providers_length) { + API_IMPL_BEGIN + const size_t MAX_LEN = 30; + int available_count = (int)(sizeof(providers_available) / sizeof(char *)); + char **out = (char **)malloc(available_count * sizeof(char *)); + if(out) { + for(int i = 0; i < available_count; i++) { + out[i] = (char *)malloc((MAX_LEN + 1) * sizeof(char)); + if(out[i]) { +#ifdef _MSC_VER + strncpy_s(out[i], MAX_LEN, providers_available[i], MAX_LEN); +#else + strncpy(out[i], providers_available[i], MAX_LEN); +#endif + out[i][MAX_LEN] = '\0'; + } + } + } + *providers_length = available_count; + *out_ptr = out; + API_IMPL_END + return NULL; +} + +ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char **ptr, + _In_ int providers_length) { + API_IMPL_BEGIN + if(ptr) { + for(int i = 0; i < providers_length; i++) { + if(ptr[i]) { + free(ptr[i]); + } + } + free(ptr); + } + API_IMPL_END + return NULL; +} + // End support for non-tensor types static constexpr OrtApiBase ort_api_base = { @@ -1421,7 +1462,7 @@ Second example, if we wanted to add and remove some members, we'd do this: In GetApi we now make it return ort_api_3 for version 3. */ -static constexpr OrtApi ort_api_1_to_3 = { +static constexpr OrtApi ort_api_1_to_4 = { // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. // Shipped as version 1 - DO NOT MODIFY (see above text for more information) @@ -1560,21 +1601,26 @@ static constexpr OrtApi ort_api_1_to_3 = { &OrtApis::ReleaseModelMetadata, // End of Version 2 - DO NOT MODIFY ABOVE (see above text for more information) - // Version 3 - In development, feel free to add/remove/rearrange here &OrtApis::CreateEnvWithGlobalThreadPools, &OrtApis::DisablePerSessionThreads, &OrtApis::CreateThreadingOptions, &OrtApis::ReleaseThreadingOptions, &OrtApis::ModelMetadataGetCustomMetadataMapKeys, - &OrtApis::AddFreeDimensionOverrideByName}; + &OrtApis::AddFreeDimensionOverrideByName, + // End of Version 3 - DO NOT MODIFY ABOVE (see above text for more information) + + // Version 4 - In development, feel free to add/remove/rearrange here + &OrtApis::GetAvailableProviders, + &OrtApis::ReleaseAvailableProviders, +}; // Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other) // If this assert hits, read the above 'Rules on how to add a new Ort API version' static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 101, "Size of version 1 API cannot change"); ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { - if (version >= 1 && version <= 3) - return &ort_api_1_to_3; + if (version >= 1 && version <= 4) + return &ort_api_1_to_4; return nullptr; // Unsupported version } diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 5c0764747c..160fe907b7 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -198,4 +198,8 @@ ORT_API_STATUS_IMPL(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMe ORT_API_STATUS_IMPL(AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, _In_ int64_t dim_value); +ORT_API_STATUS_IMPL(GetAvailableProviders, _Outptr_ char ***out_ptr, + _In_ int *providers_length); +ORT_API_STATUS_IMPL(ReleaseAvailableProviders, _In_ char **ptr, + _In_ int providers_length); } // namespace OrtApis diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 7a85b1f081..0e4fb474ad 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/graph/constants.h" #include "providers.h" @@ -581,3 +582,13 @@ TEST(CApiTest, model_metadata) { ASSERT_TRUE(custom_metadata_map_keys == nullptr); } } + +TEST(CApiTest, get_available_providers) { + const OrtApi *g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); + int len = 0; + char **providers; + ASSERT_EQ(g_ort->GetAvailableProviders(&providers, &len), nullptr); + ASSERT_TRUE(len > 0); + ASSERT_EQ(strcmp(providers[0], "CPUExecutionProvider"), 0); + ASSERT_EQ(g_ort->ReleaseAvailableProviders(providers, len), nullptr); +}