diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
index d3481d3287..93d572879b 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -430,7 +430,7 @@ namespace Microsoft.ML.OnnxRuntime
IntPtr /*(const OrtSession*)*/ session,
IntPtr /*(OrtAllocator*)*/ allocator,
out IntPtr /*(char**)*/profile_file);
- public static DOrtSessionEndProfiling OrtSessionEndProfiling;
+ public static DOrtSessionEndProfiling OrtSessionEndProfiling;
public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerName(
IntPtr /*(OrtSession*)*/ session,
@@ -567,7 +567,7 @@ namespace Microsoft.ML.OnnxRuntime
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);
[DllImport(nativeLib, CharSet = charSet)]
- public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nnapi(IntPtr /*(OrtSessionOptions*)*/ options);
+ public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nnapi(IntPtr /*(OrtSessionOptions*)*/ options, ulong nnapi_flags);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nuphar(IntPtr /*(OrtSessionOptions*) */ options, int allow_unaligned_buffers, string settings);
@@ -898,7 +898,7 @@ namespace Microsoft.ML.OnnxRuntime
/// instance of OrtModelMetadata
/// instance of OrtAllocator
/// (output) producer name from the ModelMetadata instance
- public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
+ public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
public static DOrtModelMetadataGetProducerName OrtModelMetadataGetProducerName;
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
index 3b9cec2580..14fbea3ec5 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
@@ -155,9 +155,9 @@ namespace Microsoft.ML.OnnxRuntime
///
/// Use only if you have the onnxruntime package specific to this Execution Provider.
///
- public void AppendExecutionProvider_Nnapi()
+ public void AppendExecutionProvider_Nnapi(ulong nnapi_flags)
{
- NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Nnapi(handle));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Nnapi(handle, nnapi_flags));
}
///
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index 899785bc8b..97981bde3f 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -97,7 +97,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
opt.AppendExecutionProvider_CUDA(0);
#endif
#if USE_DML
- // Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
+ // Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
// doesn't get loaded by the test process when it is eventually delay loaded by onnruntime.dll
// The managed tests binary path already contains the right DirectML.dll, so use that
@@ -122,7 +122,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
opt.AppendExecutionProvider_MIGraphX(0);
#endif
#if USE_NNAPI
- opt.AppendExecutionProvider_Nnapi();
+ opt.AppendExecutionProvider_Nnapi(0);
#endif
@@ -1770,7 +1770,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
- // TestGpu() will test the CUDA EP on CUDA enabled builds and
+ // TestGpu() will test the CUDA EP on CUDA enabled builds and
// the DML EP on DML enabled builds
[GpuFact]
private void TestGpu()
@@ -1979,9 +1979,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
var dims = new long[] { 3, 2 };
var dataBuffer = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F };
var dataHandle = GCHandle.Alloc(dataBuffer, GCHandleType.Pinned);
-
+
try
- {
+ {
unsafe
{
float* p = (float*)dataHandle.AddrOfPinnedObject();
@@ -2354,7 +2354,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
#if USE_DML
- // Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
+ // Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
// doesn't get loaded by the test process when it is eventually delay loaded by onnruntime.dll
// The managed tests binary path already contains the right DirectML.dll, so use that
diff --git a/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h b/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h
index d8b6a1ec27..e9f68dc28a 100644
--- a/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h
+++ b/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h
@@ -1,15 +1,34 @@
// Copyright 2019 JD.com Inc. JD AI
+#pragma once
#include "onnxruntime_c_api.h"
+// NNAPIFlags are bool options we want to set for NNAPI EP
+// This enum is defined as bit flats, and cannot have negative value
+// To generate a unsigned long nnapi_flags for using with OrtSessionOptionsAppendExecutionProvider_Nnapi below,
+// unsigned long nnapi_flags = 0;
+// nnapi_flags |= NNAPI_FLAG_USE_FP16;
+enum NNAPIFlags {
+ NNAPI_FLAG_USE_NONE = 0x000,
+
+ // Using fp16 relaxation in NNAPI EP, this may improve perf but may also reduce precision
+ NNAPI_FLAG_USE_FP16 = 0x001,
+
+ // Use NCHW layout in NNAPI EP, this is only available after Android API level 29
+ // Please note for now, NNAPI perform worse using NCHW compare to using NHWC
+ NNAPI_FLAG_USE_NCHW = 0x002,
+
+ // Keep NNAPI_FLAG_MAX at the end of the enum definition
+ // And assign the last NNAPIFlag to it
+ NNAPI_FLAG_LAST = NNAPI_FLAG_USE_NCHW,
+};
+
#ifdef __cplusplus
extern "C" {
#endif
-ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options);
+ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options, unsigned long nnapi_flags);
#ifdef __cplusplus
}
#endif
-
-
diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java
index 34188a21e4..b261a36708 100644
--- a/java/src/main/java/ai/onnxruntime/OrtSession.java
+++ b/java/src/main/java/ai/onnxruntime/OrtSession.java
@@ -790,7 +790,7 @@ public class OrtSession implements AutoCloseable {
*/
public void addNnapi() throws OrtException {
checkClosed();
- addNnapi(OnnxRuntime.ortApiHandle, nativeHandle);
+ addNnapi(OnnxRuntime.ortApiHandle, nativeHandle, 0);
}
/**
@@ -904,7 +904,8 @@ public class OrtSession implements AutoCloseable {
private native void addTensorrt(long apiHandle, long nativeHandle, int deviceNum)
throws OrtException;
- private native void addNnapi(long apiHandle, long nativeHandle) throws OrtException;
+ private native void addNnapi(long apiHandle, long nativeHandle, long nnapiFlags)
+ throws OrtException;
private native void addNuphar(
long apiHandle, long nativeHandle, int allowUnalignedBuffers, String settings)
diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
index 536a708ff5..bf4b63e1b7 100644
--- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
+++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
@@ -410,15 +410,15 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addTen
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: addNnapi
- * Signature: (J)V
+ * Signature: (JJJ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addNnapi
- (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
+ (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong nnapiFlags) {
(void)jobj;
#ifdef USE_NNAPI
- checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_Nnapi((OrtSessionOptions*) handle));
+ checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_Nnapi((OrtSessionOptions*) handle, (unsigned long) nnapiFlags));
#else
- (void)apiHandle;(void)handle; // Parameters used when NNAPI is defined.
+ (void)apiHandle;(void)handle;(void)nnapiFlags; // Parameters used when NNAPI is defined.
throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with NNAPI support.");
#endif
}
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc
index f7a5dbf8da..ca337cd3eb 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc
@@ -13,8 +13,9 @@ namespace onnxruntime {
constexpr const char* NNAPI = "Nnapi";
-NnapiExecutionProvider::NnapiExecutionProvider()
- : IExecutionProvider{onnxruntime::kNnapiExecutionProvider} {
+NnapiExecutionProvider::NnapiExecutionProvider(unsigned long nnapi_flags)
+ : IExecutionProvider{onnxruntime::kNnapiExecutionProvider},
+ nnapi_flags_(nnapi_flags) {
AllocatorCreationInfo device_info(
[](int) {
return onnxruntime::make_unique(OrtMemoryInfo(NNAPI, OrtAllocatorType::OrtDeviceAllocator));
@@ -224,8 +225,8 @@ common::Status NnapiExecutionProvider::Compile(const std::vector nnapi_model;
ORT_RETURN_IF_ERROR(builder.Compile(nnapi_model));
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h
index 1eded5e6d0..304e8ee5be 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h
@@ -5,11 +5,13 @@
#include "core/framework/execution_provider.h"
#include "core/providers/nnapi/nnapi_builtin/model.h"
+#include "core/providers/nnapi/nnapi_provider_factory.h"
namespace onnxruntime {
+
class NnapiExecutionProvider : public IExecutionProvider {
public:
- NnapiExecutionProvider();
+ NnapiExecutionProvider(unsigned long nnapi_flags);
virtual ~NnapiExecutionProvider();
std::vector>
@@ -17,8 +19,13 @@ class NnapiExecutionProvider : public IExecutionProvider {
const std::vector& /*kernel_registries*/) const override;
common::Status Compile(const std::vector& fused_nodes,
std::vector& node_compute_funcs) override;
+ unsigned long GetNNAPIFlags() const { return nnapi_flags_; }
private:
std::unordered_map> nnapi_models_;
+
+ // The bit flags which define bool options for NNAPI EP, bits are defined as
+ // NNAPIFlags in include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h
+ const unsigned long nnapi_flags_;
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/nnapi/nnapi_provider_factory.cc b/onnxruntime/core/providers/nnapi/nnapi_provider_factory.cc
index 11f3bf67d8..10f7ce20cb 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_provider_factory.cc
+++ b/onnxruntime/core/providers/nnapi/nnapi_provider_factory.cc
@@ -7,24 +7,25 @@
using namespace onnxruntime;
namespace onnxruntime {
-
struct NnapiProviderFactory : IExecutionProviderFactory {
- NnapiProviderFactory() {}
+ NnapiProviderFactory(unsigned long nnapi_flags)
+ : nnapi_flags_(nnapi_flags) {}
~NnapiProviderFactory() override {}
std::unique_ptr CreateProvider() override;
+ unsigned long nnapi_flags_;
};
std::unique_ptr NnapiProviderFactory::CreateProvider() {
- return onnxruntime::make_unique();
+ return onnxruntime::make_unique(nnapi_flags_);
}
-std::shared_ptr CreateExecutionProviderFactory_Nnapi() {
- return std::make_shared();
+std::shared_ptr CreateExecutionProviderFactory_Nnapi(unsigned long nnapi_flags) {
+ return std::make_shared(nnapi_flags);
}
} // namespace onnxruntime
-ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options) {
- options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Nnapi());
+ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options, unsigned long nnapi_flags) {
+ options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Nnapi(nnapi_flags));
return nullptr;
}
diff --git a/onnxruntime/test/framework/test_utils.cc b/onnxruntime/test/framework/test_utils.cc
index 6b361aad9e..07c705f012 100644
--- a/onnxruntime/test/framework/test_utils.cc
+++ b/onnxruntime/test/framework/test_utils.cc
@@ -40,7 +40,7 @@ IExecutionProvider* TestOpenVINOExecutionProvider() {
#ifdef USE_NNAPI
IExecutionProvider* TestNnapiExecutionProvider() {
- static NnapiExecutionProvider nnapi_provider;
+ static NnapiExecutionProvider nnapi_provider(0);
return &nnapi_provider;
}
#endif
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index f67b54d607..bc68b25d45 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -364,7 +364,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_nnapi) {
#ifdef USE_NNAPI
- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(sf));
+ Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(sf, 0));
#else
fprintf(stderr, "NNAPI is not supported in this build");
return -1;
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 4868245752..b2cd4796e6 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -78,7 +78,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
#endif
} else if (provider_name == onnxruntime::kNnapiExecutionProvider) {
#ifdef USE_NNAPI
- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options));
+ Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, 0));
#else
ORT_THROW("NNAPI is not supported in this build\n");
#endif
diff --git a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc
index 4432837756..f53ab7a244 100644
--- a/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc
+++ b/onnxruntime/test/providers/nnapi/nnapi_basic_test.cc
@@ -35,7 +35,7 @@ void RunAndVerifyOutputs(const std::string& model_file_name,
run_options.run_tag = so.session_logid;
InferenceSessionWrapper session_object{so, GetEnvironment()};
- ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>()));
+ ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>(0)));
ASSERT_STATUS_OK(session_object.Load(model_file_name));
ASSERT_STATUS_OK(session_object.Initialize());
@@ -141,5 +141,15 @@ TEST(NnapiExecutionProviderTest, FunctionTest) {
RunAndVerifyOutputs(model_file_name, "NnapiExecutionProviderTest.FunctionTest", feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
}
+
+TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {
+ unsigned long nnapi_flags = NNAPI_FLAG_USE_NONE;
+ nnapi_flags |= NNAPI_FLAG_USE_FP16;
+ onnxruntime::NnapiExecutionProvider nnapi_ep(nnapi_flags);
+ const auto flags = nnapi_ep.GetNNAPIFlags();
+ ASSERT_TRUE(flags & NNAPI_FLAG_USE_FP16);
+ ASSERT_FALSE(flags & NNAPI_FLAG_USE_NCHW);
+}
+
} // namespace test
} // namespace onnxruntime
diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc
index e4af663a19..e44ef0ee16 100644
--- a/onnxruntime/test/util/default_providers.cc
+++ b/onnxruntime/test/util/default_providers.cc
@@ -19,7 +19,7 @@ std::shared_ptr CreateExecutionProviderFactory_Dnnl(i
std::shared_ptr CreateExecutionProviderFactory_NGraph(const char* ng_backend_type);
std::shared_ptr CreateExecutionProviderFactory_OpenVINO(const char* device_type, bool enable_vpu_fast_compile, const char* device_id, size_t num_of_threads);
std::shared_ptr CreateExecutionProviderFactory_Nuphar(bool, const char*);
-std::shared_ptr CreateExecutionProviderFactory_Nnapi();
+std::shared_ptr CreateExecutionProviderFactory_Nnapi(unsigned long);
std::shared_ptr CreateExecutionProviderFactory_Rknpu();
std::shared_ptr CreateExecutionProviderFactory_Tensorrt(int device_id);
std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id);
@@ -96,7 +96,7 @@ std::unique_ptr DefaultNupharExecutionProvider(bool allow_un
std::unique_ptr DefaultNnapiExecutionProvider() {
#ifdef USE_NNAPI
- return CreateExecutionProviderFactory_Nnapi()->CreateProvider();
+ return CreateExecutionProviderFactory_Nnapi(0)->CreateProvider();
#else
return nullptr;
#endif