diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index d3481d3287..93d572879b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -430,7 +430,7 @@ namespace Microsoft.ML.OnnxRuntime IntPtr /*(const OrtSession*)*/ session, IntPtr /*(OrtAllocator*)*/ allocator, out IntPtr /*(char**)*/profile_file); - public static DOrtSessionEndProfiling OrtSessionEndProfiling; + public static DOrtSessionEndProfiling OrtSessionEndProfiling; public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerName( IntPtr /*(OrtSession*)*/ session, @@ -567,7 +567,7 @@ namespace Microsoft.ML.OnnxRuntime public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id); [DllImport(nativeLib, CharSet = charSet)] - public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nnapi(IntPtr /*(OrtSessionOptions*)*/ options); + public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nnapi(IntPtr /*(OrtSessionOptions*)*/ options, ulong nnapi_flags); [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nuphar(IntPtr /*(OrtSessionOptions*) */ options, int allow_unaligned_buffers, string settings); @@ -898,7 +898,7 @@ namespace Microsoft.ML.OnnxRuntime /// instance of OrtModelMetadata /// instance of OrtAllocator /// (output) producer name from the ModelMetadata instance - public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, + public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata, IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value); public static 
DOrtModelMetadataGetProducerName OrtModelMetadataGetProducerName; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs index 3b9cec2580..14fbea3ec5 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs @@ -155,9 +155,9 @@ namespace Microsoft.ML.OnnxRuntime /// /// Use only if you have the onnxruntime package specific to this Execution Provider. /// - public void AppendExecutionProvider_Nnapi() + public void AppendExecutionProvider_Nnapi(ulong nnapi_flags) { - NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Nnapi(handle)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Nnapi(handle, nnapi_flags)); } /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index 899785bc8b..97981bde3f 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -97,7 +97,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests opt.AppendExecutionProvider_CUDA(0); #endif #if USE_DML - // Explicitly set dll probe path so that the (potentially) stale system DirectML.dll + // Explicitly set dll probe path so that the (potentially) stale system DirectML.dll // doesn't get loaded by the test process when it is eventually delay loaded by onnruntime.dll // The managed tests binary path already contains the right DirectML.dll, so use that @@ -122,7 +122,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests opt.AppendExecutionProvider_MIGraphX(0); #endif #if USE_NNAPI - opt.AppendExecutionProvider_Nnapi(); + opt.AppendExecutionProvider_Nnapi(0); #endif @@ -1770,7 +1770,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests } } - // TestGpu() will test the CUDA EP on CUDA enabled builds and + // TestGpu() will test the CUDA EP on CUDA enabled 
builds and // the DML EP on DML enabled builds [GpuFact] private void TestGpu() @@ -1979,9 +1979,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests var dims = new long[] { 3, 2 }; var dataBuffer = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F }; var dataHandle = GCHandle.Alloc(dataBuffer, GCHandleType.Pinned); - + try - { + { unsafe { float* p = (float*)dataHandle.AddrOfPinnedObject(); @@ -2354,7 +2354,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests { string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx"); #if USE_DML - // Explicitly set dll probe path so that the (potentially) stale system DirectML.dll + // Explicitly set dll probe path so that the (potentially) stale system DirectML.dll // doesn't get loaded by the test process when it is eventually delay loaded by onnruntime.dll // The managed tests binary path already contains the right DirectML.dll, so use that diff --git a/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h b/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h index d8b6a1ec27..e9f68dc28a 100644 --- a/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h +++ b/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h @@ -1,15 +1,34 @@ // Copyright 2019 JD.com Inc. 
JD AI +#pragma once #include "onnxruntime_c_api.h" +// NNAPIFlags are bool options we want to set for NNAPI EP +// This enum is defined as bit flags, and cannot have negative values +// To generate an unsigned long nnapi_flags for using with OrtSessionOptionsAppendExecutionProvider_Nnapi below, +// unsigned long nnapi_flags = 0; +// nnapi_flags |= NNAPI_FLAG_USE_FP16; +enum NNAPIFlags { + NNAPI_FLAG_USE_NONE = 0x000, + + // Using fp16 relaxation in NNAPI EP, this may improve perf but may also reduce precision + NNAPI_FLAG_USE_FP16 = 0x001, + + // Use NCHW layout in NNAPI EP, this is only available after Android API level 29 + // Please note for now, NNAPI performs worse using NCHW compared to using NHWC + NNAPI_FLAG_USE_NCHW = 0x002, + + // Keep NNAPI_FLAG_LAST at the end of the enum definition + // And assign the last NNAPIFlag to it + NNAPI_FLAG_LAST = NNAPI_FLAG_USE_NCHW, +}; + #ifdef __cplusplus extern "C" { #endif -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options); +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options, unsigned long nnapi_flags); #ifdef __cplusplus } #endif - - diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 34188a21e4..b261a36708 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -790,7 +790,7 @@ public class OrtSession implements AutoCloseable { */ public void addNnapi() throws OrtException { checkClosed(); - addNnapi(OnnxRuntime.ortApiHandle, nativeHandle); + addNnapi(OnnxRuntime.ortApiHandle, nativeHandle, 0); } /** @@ -904,7 +904,8 @@ public class OrtSession implements AutoCloseable { private native void addTensorrt(long apiHandle, long nativeHandle, int deviceNum) throws OrtException; - private native void addNnapi(long apiHandle, long nativeHandle) throws OrtException; + private native void addNnapi(long
apiHandle, long nativeHandle, long nnapiFlags) + throws OrtException; private native void addNuphar( long apiHandle, long nativeHandle, int allowUnalignedBuffers, String settings) diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 536a708ff5..bf4b63e1b7 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -410,15 +410,15 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addTen /* * Class: ai_onnxruntime_OrtSession_SessionOptions * Method: addNnapi - * Signature: (J)V + * Signature: (JJJ)V */ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addNnapi - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong nnapiFlags) { (void)jobj; #ifdef USE_NNAPI - checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_Nnapi((OrtSessionOptions*) handle)); + checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_Nnapi((OrtSessionOptions*) handle, (unsigned long) nnapiFlags)); #else - (void)apiHandle;(void)handle; // Parameters used when NNAPI is defined. + (void)apiHandle;(void)handle;(void)nnapiFlags; // Parameters used when NNAPI is defined. 
throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with NNAPI support."); #endif } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index f7a5dbf8da..ca337cd3eb 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -13,8 +13,9 @@ namespace onnxruntime { constexpr const char* NNAPI = "Nnapi"; -NnapiExecutionProvider::NnapiExecutionProvider() - : IExecutionProvider{onnxruntime::kNnapiExecutionProvider} { +NnapiExecutionProvider::NnapiExecutionProvider(unsigned long nnapi_flags) + : IExecutionProvider{onnxruntime::kNnapiExecutionProvider}, + nnapi_flags_(nnapi_flags) { AllocatorCreationInfo device_info( [](int) { return onnxruntime::make_unique(OrtMemoryInfo(NNAPI, OrtAllocatorType::OrtDeviceAllocator)); @@ -224,8 +225,8 @@ common::Status NnapiExecutionProvider::Compile(const std::vector nnapi_model; ORT_RETURN_IF_ERROR(builder.Compile(nnapi_model)); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h index 1eded5e6d0..304e8ee5be 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h @@ -5,11 +5,13 @@ #include "core/framework/execution_provider.h" #include "core/providers/nnapi/nnapi_builtin/model.h" +#include "core/providers/nnapi/nnapi_provider_factory.h" namespace onnxruntime { + class NnapiExecutionProvider : public IExecutionProvider { public: - NnapiExecutionProvider(); + NnapiExecutionProvider(unsigned long nnapi_flags); virtual ~NnapiExecutionProvider(); std::vector> @@ -17,8 +19,13 @@ class NnapiExecutionProvider : public IExecutionProvider { const std::vector& 
/*kernel_registries*/) const override; common::Status Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) override; + unsigned long GetNNAPIFlags() const { return nnapi_flags_; } private: std::unordered_map> nnapi_models_; + + // The bit flags which define bool options for NNAPI EP, bits are defined as + // NNAPIFlags in include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h + const unsigned long nnapi_flags_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nnapi/nnapi_provider_factory.cc b/onnxruntime/core/providers/nnapi/nnapi_provider_factory.cc index 11f3bf67d8..10f7ce20cb 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_provider_factory.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_provider_factory.cc @@ -7,24 +7,25 @@ using namespace onnxruntime; namespace onnxruntime { - struct NnapiProviderFactory : IExecutionProviderFactory { - NnapiProviderFactory() {} + NnapiProviderFactory(unsigned long nnapi_flags) + : nnapi_flags_(nnapi_flags) {} ~NnapiProviderFactory() override {} std::unique_ptr CreateProvider() override; + unsigned long nnapi_flags_; }; std::unique_ptr NnapiProviderFactory::CreateProvider() { - return onnxruntime::make_unique(); + return onnxruntime::make_unique(nnapi_flags_); } -std::shared_ptr CreateExecutionProviderFactory_Nnapi() { - return std::make_shared(); +std::shared_ptr CreateExecutionProviderFactory_Nnapi(unsigned long nnapi_flags) { + return std::make_shared(nnapi_flags); } } // namespace onnxruntime -ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options) { - options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Nnapi()); +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options, unsigned long nnapi_flags) { + options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Nnapi(nnapi_flags)); return nullptr; } diff --git 
a/onnxruntime/test/framework/test_utils.cc b/onnxruntime/test/framework/test_utils.cc index 6b361aad9e..07c705f012 100644 --- a/onnxruntime/test/framework/test_utils.cc +++ b/onnxruntime/test/framework/test_utils.cc @@ -40,7 +40,7 @@ IExecutionProvider* TestOpenVINOExecutionProvider() { #ifdef USE_NNAPI IExecutionProvider* TestNnapiExecutionProvider() { - static NnapiExecutionProvider nnapi_provider; + static NnapiExecutionProvider nnapi_provider(0); return &nnapi_provider; } #endif diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index f67b54d607..bc68b25d45 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -364,7 +364,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { } if (enable_nnapi) { #ifdef USE_NNAPI - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(sf)); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(sf, 0)); #else fprintf(stderr, "NNAPI is not supported in this build"); return -1; diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 4868245752..b2cd4796e6 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -78,7 +78,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device #endif } else if (provider_name == onnxruntime::kNnapiExecutionProvider) { #ifdef USE_NNAPI - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options)); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, 0)); #else ORT_THROW("NNAPI is not supported in this build\n"); #endif diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc index 4432837756..f53ab7a244 100644 --- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc +++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc @@ -35,7 +35,7 @@ void 
RunAndVerifyOutputs(const std::string& model_file_name, run_options.run_tag = so.session_logid; InferenceSessionWrapper session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>())); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>(0))); ASSERT_STATUS_OK(session_object.Load(model_file_name)); ASSERT_STATUS_OK(session_object.Initialize()); @@ -141,5 +141,15 @@ TEST(NnapiExecutionProviderTest, FunctionTest) { RunAndVerifyOutputs(model_file_name, "NnapiExecutionProviderTest.FunctionTest", feeds, output_names, expected_dims_mul_m, expected_values_mul_m); } + +TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) { + unsigned long nnapi_flags = NNAPI_FLAG_USE_NONE; + nnapi_flags |= NNAPI_FLAG_USE_FP16; + onnxruntime::NnapiExecutionProvider nnapi_ep(nnapi_flags); + const auto flags = nnapi_ep.GetNNAPIFlags(); + ASSERT_TRUE(flags & NNAPI_FLAG_USE_FP16); + ASSERT_FALSE(flags & NNAPI_FLAG_USE_NCHW); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index e4af663a19..e44ef0ee16 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -19,7 +19,7 @@ std::shared_ptr CreateExecutionProviderFactory_Dnnl(i std::shared_ptr CreateExecutionProviderFactory_NGraph(const char* ng_backend_type); std::shared_ptr CreateExecutionProviderFactory_OpenVINO(const char* device_type, bool enable_vpu_fast_compile, const char* device_id, size_t num_of_threads); std::shared_ptr CreateExecutionProviderFactory_Nuphar(bool, const char*); -std::shared_ptr CreateExecutionProviderFactory_Nnapi(); +std::shared_ptr CreateExecutionProviderFactory_Nnapi(unsigned long); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(int 
device_id); std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id); @@ -96,7 +96,7 @@ std::unique_ptr DefaultNupharExecutionProvider(bool allow_un std::unique_ptr DefaultNnapiExecutionProvider() { #ifdef USE_NNAPI - return CreateExecutionProviderFactory_Nnapi()->CreateProvider(); + return CreateExecutionProviderFactory_Nnapi(0)->CreateProvider(); #else return nullptr; #endif