diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index de6886bcdd..cd3f76fa8e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -56,6 +56,10 @@ const OrtApi& Global::api_ = *OrtGetApiBase()->GetApi(ORT_API_VERSION); // This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions inline const OrtApi& GetApi() { return Global::api_; } +// This is a C++ wrapper for GetAvailableProviders() C API and returns +// a vector of strings representing the available execution providers. +std::vector GetAvailableProviders(); + // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type // This can't be done in the C API since C doesn't have function overloading. #define ORT_DEFINE_RELEASE(NAME) \ @@ -374,4 +378,4 @@ struct CustomOpBase : OrtCustomOp { } // namespace Ort -#include "onnxruntime_cxx_inline.h" \ No newline at end of file +#include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 5c60743cf7..2041f92f66 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -615,4 +615,14 @@ inline SessionOptions& SessionOptions::DisablePerSessionThreads() { ThrowOnError(Global::api_.DisablePerSessionThreads(p_)); return *this; } -} // namespace Ort \ No newline at end of file + +inline std::vector GetAvailableProviders() { + int len; + char **providers; + const OrtApi& api = GetApi(); + ThrowOnError(api.GetAvailableProviders(&providers, &len)); + std::vector available_providers(providers, providers + len); + ThrowOnError(api.ReleaseAvailableProviders(providers, len)); + return available_providers; +} +} // namespace Ort diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 0e4fb474ad..2a4a23391e 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -592,3 +592,9 @@ TEST(CApiTest, get_available_providers) { ASSERT_EQ(strcmp(providers[0], "CPUExecutionProvider"), 0); ASSERT_EQ(g_ort->ReleaseAvailableProviders(providers, len), nullptr); } + +TEST(CApiTest, get_available_providers_cpp) { + std::vector providers = Ort::GetAvailableProviders(); + ASSERT_TRUE(providers.size() > 0); + ASSERT_TRUE(providers[0] == std::string("CPUExecutionProvider")); +}