Add runtime options for NNAPI EP (#5576)

* Add options for nnapi ep

* Add nnapi flags test

* add comments

* Add flag comments

* Make the flags bitset const

* Fix build break

* Add stub changes to java and c# api

* Fix java related build break

* Fix java build break

* Switch to bit flags instead of bitset
This commit is contained in:
Guoyu Wang 2020-11-04 10:08:43 -08:00 committed by GitHub
parent 2ad7bcb766
commit a2b551ff08
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 77 additions and 38 deletions

View file

@ -430,7 +430,7 @@ namespace Microsoft.ML.OnnxRuntime
IntPtr /*(const OrtSession*)*/ session,
IntPtr /*(OrtAllocator*)*/ allocator,
out IntPtr /*(char**)*/profile_file);
public static DOrtSessionEndProfiling OrtSessionEndProfiling;
public static DOrtSessionEndProfiling OrtSessionEndProfiling;
public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerName(
IntPtr /*(OrtSession*)*/ session,
@ -567,7 +567,7 @@ namespace Microsoft.ML.OnnxRuntime
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nnapi(IntPtr /*(OrtSessionOptions*)*/ options);
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nnapi(IntPtr /*(OrtSessionOptions*)*/ options, ulong nnapi_flags);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nuphar(IntPtr /*(OrtSessionOptions*) */ options, int allow_unaligned_buffers, string settings);
@ -898,7 +898,7 @@ namespace Microsoft.ML.OnnxRuntime
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
/// <param name="allocator">instance of OrtAllocator</param>
/// <param name="value">(output) producer name from the ModelMetadata instance</param>
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
public static DOrtModelMetadataGetProducerName OrtModelMetadataGetProducerName;

View file

@ -155,9 +155,9 @@ namespace Microsoft.ML.OnnxRuntime
/// <summary>
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
public void AppendExecutionProvider_Nnapi()
public void AppendExecutionProvider_Nnapi(ulong nnapi_flags)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Nnapi(handle));
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Nnapi(handle, nnapi_flags));
}
/// <summary>

View file

@ -97,7 +97,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
opt.AppendExecutionProvider_CUDA(0);
#endif
#if USE_DML
// Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
// Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
// doesn't get loaded by the test process when it is eventually delay loaded by onnxruntime.dll
// The managed tests binary path already contains the right DirectML.dll, so use that
@ -122,7 +122,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
opt.AppendExecutionProvider_MIGraphX(0);
#endif
#if USE_NNAPI
opt.AppendExecutionProvider_Nnapi();
opt.AppendExecutionProvider_Nnapi(0);
#endif
@ -1770,7 +1770,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
}
// TestGpu() will test the CUDA EP on CUDA enabled builds and
// TestGpu() will test the CUDA EP on CUDA enabled builds and
// the DML EP on DML enabled builds
[GpuFact]
private void TestGpu()
@ -1979,9 +1979,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
var dims = new long[] { 3, 2 };
var dataBuffer = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F };
var dataHandle = GCHandle.Alloc(dataBuffer, GCHandleType.Pinned);
try
{
{
unsafe
{
float* p = (float*)dataHandle.AddrOfPinnedObject();
@ -2354,7 +2354,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
#if USE_DML
// Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
// Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
// doesn't get loaded by the test process when it is eventually delay loaded by onnxruntime.dll
// The managed tests binary path already contains the right DirectML.dll, so use that

View file

@ -1,15 +1,34 @@
// Copyright 2019 JD.com Inc. JD AI
#pragma once
#include "onnxruntime_c_api.h"
// NNAPIFlags are the boolean options that can be set for the NNAPI EP.
// The enum is defined as bit flags, so no enumerator may have a negative value.
// To build the unsigned long nnapi_flags argument for
// OrtSessionOptionsAppendExecutionProvider_Nnapi below:
//   unsigned long nnapi_flags = 0;
//   nnapi_flags |= NNAPI_FLAG_USE_FP16;
enum NNAPIFlags {
  // Zero value: no flag bits set
  NNAPI_FLAG_USE_NONE = 0,

  // Use fp16 relaxation in the NNAPI EP; this may improve perf but may also reduce precision
  NNAPI_FLAG_USE_FP16 = 1 << 0,

  // Use NCHW layout in the NNAPI EP; this is only available after Android API level 29.
  // Please note that, for now, NNAPI performs worse using NCHW compared to using NHWC
  NNAPI_FLAG_USE_NCHW = 1 << 1,

  // Keep NNAPI_FLAG_LAST at the end of the enum definition
  // and assign the last NNAPIFlag to it
  NNAPI_FLAG_LAST = NNAPI_FLAG_USE_NCHW,
};
#ifdef __cplusplus
extern "C" {
#endif
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options);
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options, unsigned long nnapi_flags);
#ifdef __cplusplus
}
#endif

View file

@ -790,7 +790,7 @@ public class OrtSession implements AutoCloseable {
*/
public void addNnapi() throws OrtException {
checkClosed();
addNnapi(OnnxRuntime.ortApiHandle, nativeHandle);
addNnapi(OnnxRuntime.ortApiHandle, nativeHandle, 0);
}
/**
@ -904,7 +904,8 @@ public class OrtSession implements AutoCloseable {
private native void addTensorrt(long apiHandle, long nativeHandle, int deviceNum)
throws OrtException;
private native void addNnapi(long apiHandle, long nativeHandle) throws OrtException;
private native void addNnapi(long apiHandle, long nativeHandle, long nnapiFlags)
throws OrtException;
private native void addNuphar(
long apiHandle, long nativeHandle, int allowUnalignedBuffers, String settings)

View file

@ -410,15 +410,15 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addTen
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: addNnapi
* Signature: (J)V
* Signature: (JJJ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addNnapi
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong nnapiFlags) {
(void)jobj;
#ifdef USE_NNAPI
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_Nnapi((OrtSessionOptions*) handle));
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_Nnapi((OrtSessionOptions*) handle, (unsigned long) nnapiFlags));
#else
(void)apiHandle;(void)handle; // Parameters used when NNAPI is defined.
(void)apiHandle;(void)handle;(void)nnapiFlags; // Parameters used when NNAPI is defined.
throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with NNAPI support.");
#endif
}

View file

@ -13,8 +13,9 @@ namespace onnxruntime {
constexpr const char* NNAPI = "Nnapi";
NnapiExecutionProvider::NnapiExecutionProvider()
: IExecutionProvider{onnxruntime::kNnapiExecutionProvider} {
NnapiExecutionProvider::NnapiExecutionProvider(unsigned long nnapi_flags)
: IExecutionProvider{onnxruntime::kNnapiExecutionProvider},
nnapi_flags_(nnapi_flags) {
AllocatorCreationInfo device_info(
[](int) {
return onnxruntime::make_unique<CPUAllocator>(OrtMemoryInfo(NNAPI, OrtAllocatorType::OrtDeviceAllocator));
@ -224,8 +225,8 @@ common::Status NnapiExecutionProvider::Compile(const std::vector<onnxruntime::No
{
onnxruntime::GraphViewer graph_viewer(graph_body);
nnapi::ModelBuilder builder(graph_viewer);
builder.SetUseNCHW(false);
builder.SetUseFp16(false);
builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW);
builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16);
std::unique_ptr<nnapi::Model> nnapi_model;
ORT_RETURN_IF_ERROR(builder.Compile(nnapi_model));

View file

@ -5,11 +5,13 @@
#include "core/framework/execution_provider.h"
#include "core/providers/nnapi/nnapi_builtin/model.h"
#include "core/providers/nnapi/nnapi_provider_factory.h"
namespace onnxruntime {
class NnapiExecutionProvider : public IExecutionProvider {
public:
NnapiExecutionProvider();
NnapiExecutionProvider(unsigned long nnapi_flags);
virtual ~NnapiExecutionProvider();
std::vector<std::unique_ptr<ComputeCapability>>
@ -17,8 +19,13 @@ class NnapiExecutionProvider : public IExecutionProvider {
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;
unsigned long GetNNAPIFlags() const { return nnapi_flags_; }
private:
std::unordered_map<std::string, std::unique_ptr<onnxruntime::nnapi::Model>> nnapi_models_;
// The bit flags which define bool options for NNAPI EP, bits are defined as
// NNAPIFlags in include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h
const unsigned long nnapi_flags_;
};
} // namespace onnxruntime

View file

@ -7,24 +7,25 @@
using namespace onnxruntime;
namespace onnxruntime {
struct NnapiProviderFactory : IExecutionProviderFactory {
NnapiProviderFactory() {}
NnapiProviderFactory(unsigned long nnapi_flags)
: nnapi_flags_(nnapi_flags) {}
~NnapiProviderFactory() override {}
std::unique_ptr<IExecutionProvider> CreateProvider() override;
unsigned long nnapi_flags_;
};
std::unique_ptr<IExecutionProvider> NnapiProviderFactory::CreateProvider() {
return onnxruntime::make_unique<NnapiExecutionProvider>();
return onnxruntime::make_unique<NnapiExecutionProvider>(nnapi_flags_);
}
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi() {
return std::make_shared<onnxruntime::NnapiProviderFactory>();
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(unsigned long nnapi_flags) {
return std::make_shared<onnxruntime::NnapiProviderFactory>(nnapi_flags);
}
} // namespace onnxruntime
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options) {
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Nnapi());
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options, unsigned long nnapi_flags) {
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Nnapi(nnapi_flags));
return nullptr;
}

View file

@ -40,7 +40,7 @@ IExecutionProvider* TestOpenVINOExecutionProvider() {
#ifdef USE_NNAPI
IExecutionProvider* TestNnapiExecutionProvider() {
static NnapiExecutionProvider nnapi_provider;
static NnapiExecutionProvider nnapi_provider(0);
return &nnapi_provider;
}
#endif

View file

@ -364,7 +364,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
}
if (enable_nnapi) {
#ifdef USE_NNAPI
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(sf));
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(sf, 0));
#else
fprintf(stderr, "NNAPI is not supported in this build");
return -1;

View file

@ -78,7 +78,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
#endif
} else if (provider_name == onnxruntime::kNnapiExecutionProvider) {
#ifdef USE_NNAPI
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options));
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, 0));
#else
ORT_THROW("NNAPI is not supported in this build\n");
#endif

View file

@ -35,7 +35,7 @@ void RunAndVerifyOutputs(const std::string& model_file_name,
run_options.run_tag = so.session_logid;
InferenceSessionWrapper session_object{so, GetEnvironment()};
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>()));
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>(0)));
ASSERT_STATUS_OK(session_object.Load(model_file_name));
ASSERT_STATUS_OK(session_object.Initialize());
@ -141,5 +141,15 @@ TEST(NnapiExecutionProviderTest, FunctionTest) {
RunAndVerifyOutputs(model_file_name, "NnapiExecutionProviderTest.FunctionTest", feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
}
// Checks that the flags handed to the NnapiExecutionProvider constructor are
// what GetNNAPIFlags() reports: the FP16 bit is set and the NCHW bit is not.
TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {
  const unsigned long requested_flags = NNAPI_FLAG_USE_NONE | NNAPI_FLAG_USE_FP16;
  onnxruntime::NnapiExecutionProvider provider(requested_flags);

  const unsigned long reported_flags = provider.GetNNAPIFlags();
  ASSERT_TRUE((reported_flags & NNAPI_FLAG_USE_FP16) != 0);
  ASSERT_FALSE((reported_flags & NNAPI_FLAG_USE_NCHW) != 0);
}
} // namespace test
} // namespace onnxruntime

View file

@ -19,7 +19,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(i
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_NGraph(const char* ng_backend_type);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_OpenVINO(const char* device_type, bool enable_vpu_fast_compile, const char* device_id, size_t num_of_threads);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar(bool, const char*);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi();
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(unsigned long);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu();
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(int device_id);
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_MIGraphX(int device_id);
@ -96,7 +96,7 @@ std::unique_ptr<IExecutionProvider> DefaultNupharExecutionProvider(bool allow_un
std::unique_ptr<IExecutionProvider> DefaultNnapiExecutionProvider() {
#ifdef USE_NNAPI
return CreateExecutionProviderFactory_Nnapi()->CreateProvider();
return CreateExecutionProviderFactory_Nnapi(0)->CreateProvider();
#else
return nullptr;
#endif