Added GetAvailableProviders() to C API (#4247)

* Added GetAvailableProviders to C API

* Fix API version and Windows build error

* Changed function name

* Changed ORT_API_VERSION to 4

* Moved all_providers array to constants.h

* Move check for providers to constants.h

* Changed name of array to avoid warning

* Address review comment

* Added unit test
This commit is contained in:
Prabhat 2020-06-22 07:40:25 +05:30 committed by GitHub
parent 175983c082
commit 57fabfba7a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 129 additions and 7 deletions

View file

@ -38,4 +38,46 @@ constexpr const char* kDmlExecutionProvider = "DmlExecutionProvider";
constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";
constexpr const char* kAclExecutionProvider = "ACLExecutionProvider";
constexpr const char* kArmNNExecutionProvider = "ArmNNExecutionProvider";
constexpr const char *providers_available[] = {
kCpuExecutionProvider,
#ifdef USE_CUDA
kCudaExecutionProvider,
#endif
#ifdef USE_DNNL
kDnnlExecutionProvider,
#endif
#ifdef USE_NGRAPH
kNGraphExecutionProvider,
#endif
#ifdef USE_OPENVINO
kOpenVINOExecutionProvider,
#endif
#ifdef USE_NUPHAR
kNupharExecutionProvider,
#endif
#ifdef USE_VITISAI
kVitisAIExecutionProvider,
#endif
#ifdef USE_TENSORRT
kTensorrtExecutionProvider,
#endif
#ifdef USE_NNAPI
kNnapiExecutionProvider,
#endif
#ifdef USE_RKNPU
kRknpuExecutionProvider,
#endif
#ifdef USE_DML
kDmlExecutionProvider,
#endif
#ifdef USE_MIGRAPHX
kMIGraphXExecutionProvider,
#endif
#ifdef USE_ACL
kAclExecutionProvider,
#endif
#ifdef USE_ARMNN
kArmNNExecutionProvider,
#endif
};
} // namespace onnxruntime

View file

@ -7,7 +7,7 @@
#include <string.h>
// This value is used in structures passed to ORT so that a newer version of ORT will still work with them
#define ORT_API_VERSION 3
#define ORT_API_VERSION 4
#ifdef __cplusplus
extern "C" {
@ -817,6 +817,25 @@ struct OrtApi {
ORT_API2_STATUS(AddFreeDimensionOverrideByName,
_Inout_ OrtSessionOptions* options, _In_ const char* dim_name,
_In_ int64_t dim_value);
/**
* \param out_ptr will hold a pointer to the array of char *
* representing available providers.
* \param provider_length is a pointer to an int variable where
* the number of available providers will be added.
* The caller is responsible for freeing each char * and the pointer
* array by calling ReleaseAvailableProviders().
*/
ORT_API2_STATUS(GetAvailableProviders, _Outptr_ char ***out_ptr,
_In_ int *provider_length);
/**
* \param ptr is the pointer to an array of available providers you
* get after calling GetAvailableProviders().
* \param providers_length is the number of available providers.
*/
ORT_API2_STATUS(ReleaseAvailableProviders, _In_ char **ptr,
_In_ int providers_length);
};
/*

View file

@ -14,6 +14,7 @@
#include "core/common/logging/logging.h"
#include "core/common/status.h"
#include "core/common/safeint.h"
#include "core/graph/constants.h"
#include "core/graph/graph.h"
#include "core/framework/allocator.h"
#include "core/framework/tensor.h"
@ -981,7 +982,7 @@ ORT_STATUS_PTR OrtGetValueImplSeqOfTensors(_In_ const OrtValue* p_ml_value, int
return t_disp.template InvokeWithUnsupportedPolicy<UnsupportedReturnFailStatus>(allocator, one_tensor, out);
}
#ifdef _MSVC_VER
#ifdef _MSC_VER
#pragma warning(pop)
#endif
@ -1375,6 +1376,46 @@ ORT_API_STATUS_IMPL(OrtApis::GetOpaqueValue, _In_ const char* domain_name, _In_
return nullptr;
}
ORT_API_STATUS_IMPL(OrtApis::GetAvailableProviders, _Outptr_ char ***out_ptr,
_In_ int *providers_length) {
API_IMPL_BEGIN
const size_t MAX_LEN = 30;
int available_count = (int)(sizeof(providers_available) / sizeof(char *));
char **out = (char **)malloc(available_count * sizeof(char *));
if(out) {
for(int i = 0; i < available_count; i++) {
out[i] = (char *)malloc((MAX_LEN + 1) * sizeof(char));
if(out[i]) {
#ifdef _MSC_VER
strncpy_s(out[i], MAX_LEN, providers_available[i], MAX_LEN);
#else
strncpy(out[i], providers_available[i], MAX_LEN);
#endif
out[i][MAX_LEN] = '\0';
}
}
}
*providers_length = available_count;
*out_ptr = out;
API_IMPL_END
return NULL;
}
ORT_API_STATUS_IMPL(OrtApis::ReleaseAvailableProviders, _In_ char **ptr,
_In_ int providers_length) {
API_IMPL_BEGIN
if(ptr) {
for(int i = 0; i < providers_length; i++) {
if(ptr[i]) {
free(ptr[i]);
}
}
free(ptr);
}
API_IMPL_END
return NULL;
}
// End support for non-tensor types
static constexpr OrtApiBase ort_api_base = {
@ -1421,7 +1462,7 @@ Second example, if we wanted to add and remove some members, we'd do this:
In GetApi we now make it return ort_api_3 for version 3.
*/
static constexpr OrtApi ort_api_1_to_3 = {
static constexpr OrtApi ort_api_1_to_4 = {
// NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering.
// Shipped as version 1 - DO NOT MODIFY (see above text for more information)
@ -1560,21 +1601,26 @@ static constexpr OrtApi ort_api_1_to_3 = {
&OrtApis::ReleaseModelMetadata,
// End of Version 2 - DO NOT MODIFY ABOVE (see above text for more information)
// Version 3 - In development, feel free to add/remove/rearrange here
&OrtApis::CreateEnvWithGlobalThreadPools,
&OrtApis::DisablePerSessionThreads,
&OrtApis::CreateThreadingOptions,
&OrtApis::ReleaseThreadingOptions,
&OrtApis::ModelMetadataGetCustomMetadataMapKeys,
&OrtApis::AddFreeDimensionOverrideByName};
&OrtApis::AddFreeDimensionOverrideByName,
// End of Version 3 - DO NOT MODIFY ABOVE (see above text for more information)
// Version 4 - In development, feel free to add/remove/rearrange here
&OrtApis::GetAvailableProviders,
&OrtApis::ReleaseAvailableProviders,
};
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
// If this assert hits, read the above 'Rules on how to add a new Ort API version'
static_assert(offsetof(OrtApi, ReleaseCustomOpDomain) / sizeof(void*) == 101, "Size of version 1 API cannot change");
ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) {
if (version >= 1 && version <= 3)
return &ort_api_1_to_3;
if (version >= 1 && version <= 4)
return &ort_api_1_to_4;
return nullptr; // Unsupported version
}

View file

@ -198,4 +198,8 @@ ORT_API_STATUS_IMPL(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMe
ORT_API_STATUS_IMPL(AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, _In_ int64_t dim_value);
ORT_API_STATUS_IMPL(GetAvailableProviders, _Outptr_ char ***out_ptr,
_In_ int *providers_length);
ORT_API_STATUS_IMPL(ReleaseAvailableProviders, _In_ char **ptr,
_In_ int providers_length);
} // namespace OrtApis

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include <core/common/make_unique.h>
#include "core/session/onnxruntime_c_api.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/graph/constants.h"
#include "providers.h"
@ -581,3 +582,13 @@ TEST(CApiTest, model_metadata) {
ASSERT_TRUE(custom_metadata_map_keys == nullptr);
}
}
TEST(CApiTest, get_available_providers) {
const OrtApi *g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
int len = 0;
char **providers;
ASSERT_EQ(g_ort->GetAvailableProviders(&providers, &len), nullptr);
ASSERT_TRUE(len > 0);
ASSERT_EQ(strcmp(providers[0], "CPUExecutionProvider"), 0);
ASSERT_EQ(g_ort->ReleaseAvailableProviders(providers, len), nullptr);
}