Add C++ wrapper for GetAvailableProviders() C API (#4313)

This commit is contained in:
Prabhat 2020-06-25 13:11:55 +05:30 committed by GitHub
parent a6d10376df
commit 151ef1c8a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 2 deletions

View file

@ -56,6 +56,10 @@ const OrtApi& Global<T>::api_ = *OrtGetApiBase()->GetApi(ORT_API_VERSION);
// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions
inline const OrtApi& GetApi() { return Global<void>::api_; }
// This is a C++ wrapper for GetAvailableProviders() C API and returns
// a vector of strings representing the available execution providers.
std::vector<std::string> GetAvailableProviders();
// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
// This can't be done in the C API since C doesn't have function overloading.
#define ORT_DEFINE_RELEASE(NAME) \
@ -374,4 +378,4 @@ struct CustomOpBase : OrtCustomOp {
} // namespace Ort
#include "onnxruntime_cxx_inline.h"
#include "onnxruntime_cxx_inline.h"

View file

@ -615,4 +615,14 @@ inline SessionOptions& SessionOptions::DisablePerSessionThreads() {
ThrowOnError(Global<void>::api_.DisablePerSessionThreads(p_));
return *this;
}
} // namespace Ort
inline std::vector<std::string> GetAvailableProviders() {
int len;
char **providers;
const OrtApi& api = GetApi();
ThrowOnError(api.GetAvailableProviders(&providers, &len));
std::vector<std::string> available_providers(providers, providers + len);
ThrowOnError(api.ReleaseAvailableProviders(providers, len));
return available_providers;
}
} // namespace Ort

View file

@ -592,3 +592,9 @@ TEST(CApiTest, get_available_providers) {
ASSERT_EQ(strcmp(providers[0], "CPUExecutionProvider"), 0);
ASSERT_EQ(g_ort->ReleaseAvailableProviders(providers, len), nullptr);
}
TEST(CApiTest, get_available_providers_cpp) {
std::vector<std::string> providers = Ort::GetAvailableProviders();
ASSERT_TRUE(providers.size() > 0);
ASSERT_TRUE(providers[0] == std::string("CPUExecutionProvider"));
}