mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add runtime options for NNAPI EP (#5576)
* Add options for nnapi ep * Add nnapi flags test * add comments * Add flag comments * Make the flags bitset const * Fix build break * Add stub changes to java and c# api * Fix java related build break * Fix java build break * Switch to bit flags instead of bitset
This commit is contained in:
parent
2ad7bcb766
commit
a2b551ff08
14 changed files with 77 additions and 38 deletions
|
|
@ -430,7 +430,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
IntPtr /*(const OrtSession*)*/ session,
|
||||
IntPtr /*(OrtAllocator*)*/ allocator,
|
||||
out IntPtr /*(char**)*/profile_file);
|
||||
public static DOrtSessionEndProfiling OrtSessionEndProfiling;
|
||||
public static DOrtSessionEndProfiling OrtSessionEndProfiling;
|
||||
|
||||
public delegate IntPtr /*(OrtStatus*)*/DOrtSessionGetOverridableInitializerName(
|
||||
IntPtr /*(OrtSession*)*/ session,
|
||||
|
|
@ -567,7 +567,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nnapi(IntPtr /*(OrtSessionOptions*)*/ options);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nnapi(IntPtr /*(OrtSessionOptions*)*/ options, ulong nnapi_flags);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nuphar(IntPtr /*(OrtSessionOptions*) */ options, int allow_unaligned_buffers, string settings);
|
||||
|
|
@ -898,7 +898,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// <param name="modelMetadata">instance of OrtModelMetadata</param>
|
||||
/// <param name="allocator">instance of OrtAllocator</param>
|
||||
/// <param name="value">(output) producer name from the ModelMetadata instance</param>
|
||||
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
|
||||
public delegate IntPtr /* (OrtStatus*) */ DOrtModelMetadataGetProducerName(IntPtr /* (const OrtModelMetadata*) */ modelMetadata,
|
||||
IntPtr /* (OrtAllocator*) */ allocator, out IntPtr /* (char**) */ value);
|
||||
public static DOrtModelMetadataGetProducerName OrtModelMetadataGetProducerName;
|
||||
|
||||
|
|
|
|||
|
|
@ -155,9 +155,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// <summary>
|
||||
/// Use only if you have the onnxruntime package specific to this Execution Provider.
|
||||
/// </summary>
|
||||
public void AppendExecutionProvider_Nnapi()
|
||||
public void AppendExecutionProvider_Nnapi(ulong nnapi_flags)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Nnapi(handle));
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Nnapi(handle, nnapi_flags));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
opt.AppendExecutionProvider_CUDA(0);
|
||||
#endif
|
||||
#if USE_DML
|
||||
// Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
|
||||
// Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
|
||||
// doesn't get loaded by the test process when it is eventually delay loaded by onnruntime.dll
|
||||
// The managed tests binary path already contains the right DirectML.dll, so use that
|
||||
|
||||
|
|
@ -122,7 +122,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
opt.AppendExecutionProvider_MIGraphX(0);
|
||||
#endif
|
||||
#if USE_NNAPI
|
||||
opt.AppendExecutionProvider_Nnapi();
|
||||
opt.AppendExecutionProvider_Nnapi(0);
|
||||
#endif
|
||||
|
||||
|
||||
|
|
@ -1770,7 +1770,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
}
|
||||
|
||||
// TestGpu() will test the CUDA EP on CUDA enabled builds and
|
||||
// TestGpu() will test the CUDA EP on CUDA enabled builds and
|
||||
// the DML EP on DML enabled builds
|
||||
[GpuFact]
|
||||
private void TestGpu()
|
||||
|
|
@ -1979,9 +1979,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
var dims = new long[] { 3, 2 };
|
||||
var dataBuffer = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F };
|
||||
var dataHandle = GCHandle.Alloc(dataBuffer, GCHandleType.Pinned);
|
||||
|
||||
|
||||
try
|
||||
{
|
||||
{
|
||||
unsafe
|
||||
{
|
||||
float* p = (float*)dataHandle.AddrOfPinnedObject();
|
||||
|
|
@ -2354,7 +2354,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
{
|
||||
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
|
||||
#if USE_DML
|
||||
// Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
|
||||
// Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
|
||||
// doesn't get loaded by the test process when it is eventually delay loaded by onnruntime.dll
|
||||
// The managed tests binary path already contains the right DirectML.dll, so use that
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,34 @@
|
|||
// Copyright 2019 JD.com Inc. JD AI
|
||||
#pragma once
|
||||
|
||||
#include "onnxruntime_c_api.h"
|
||||
|
||||
// NNAPIFlags are bool options we want to set for NNAPI EP
|
||||
// This enum is defined as bit flats, and cannot have negative value
|
||||
// To generate a unsigned long nnapi_flags for using with OrtSessionOptionsAppendExecutionProvider_Nnapi below,
|
||||
// unsigned long nnapi_flags = 0;
|
||||
// nnapi_flags |= NNAPI_FLAG_USE_FP16;
|
||||
enum NNAPIFlags {
|
||||
NNAPI_FLAG_USE_NONE = 0x000,
|
||||
|
||||
// Using fp16 relaxation in NNAPI EP, this may improve perf but may also reduce precision
|
||||
NNAPI_FLAG_USE_FP16 = 0x001,
|
||||
|
||||
// Use NCHW layout in NNAPI EP, this is only available after Android API level 29
|
||||
// Please note for now, NNAPI perform worse using NCHW compare to using NHWC
|
||||
NNAPI_FLAG_USE_NCHW = 0x002,
|
||||
|
||||
// Keep NNAPI_FLAG_MAX at the end of the enum definition
|
||||
// And assign the last NNAPIFlag to it
|
||||
NNAPI_FLAG_LAST = NNAPI_FLAG_USE_NCHW,
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options);
|
||||
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options, unsigned long nnapi_flags);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -790,7 +790,7 @@ public class OrtSession implements AutoCloseable {
|
|||
*/
|
||||
public void addNnapi() throws OrtException {
|
||||
checkClosed();
|
||||
addNnapi(OnnxRuntime.ortApiHandle, nativeHandle);
|
||||
addNnapi(OnnxRuntime.ortApiHandle, nativeHandle, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -904,7 +904,8 @@ public class OrtSession implements AutoCloseable {
|
|||
private native void addTensorrt(long apiHandle, long nativeHandle, int deviceNum)
|
||||
throws OrtException;
|
||||
|
||||
private native void addNnapi(long apiHandle, long nativeHandle) throws OrtException;
|
||||
private native void addNnapi(long apiHandle, long nativeHandle, long nnapiFlags)
|
||||
throws OrtException;
|
||||
|
||||
private native void addNuphar(
|
||||
long apiHandle, long nativeHandle, int allowUnalignedBuffers, String settings)
|
||||
|
|
|
|||
|
|
@ -410,15 +410,15 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addTen
|
|||
/*
|
||||
* Class: ai_onnxruntime_OrtSession_SessionOptions
|
||||
* Method: addNnapi
|
||||
* Signature: (J)V
|
||||
* Signature: (JJJ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addNnapi
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
|
||||
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong nnapiFlags) {
|
||||
(void)jobj;
|
||||
#ifdef USE_NNAPI
|
||||
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_Nnapi((OrtSessionOptions*) handle));
|
||||
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_Nnapi((OrtSessionOptions*) handle, (unsigned long) nnapiFlags));
|
||||
#else
|
||||
(void)apiHandle;(void)handle; // Parameters used when NNAPI is defined.
|
||||
(void)apiHandle;(void)handle;(void)nnapiFlags; // Parameters used when NNAPI is defined.
|
||||
throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with NNAPI support.");
|
||||
#endif
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,8 +13,9 @@ namespace onnxruntime {
|
|||
|
||||
constexpr const char* NNAPI = "Nnapi";
|
||||
|
||||
NnapiExecutionProvider::NnapiExecutionProvider()
|
||||
: IExecutionProvider{onnxruntime::kNnapiExecutionProvider} {
|
||||
NnapiExecutionProvider::NnapiExecutionProvider(unsigned long nnapi_flags)
|
||||
: IExecutionProvider{onnxruntime::kNnapiExecutionProvider},
|
||||
nnapi_flags_(nnapi_flags) {
|
||||
AllocatorCreationInfo device_info(
|
||||
[](int) {
|
||||
return onnxruntime::make_unique<CPUAllocator>(OrtMemoryInfo(NNAPI, OrtAllocatorType::OrtDeviceAllocator));
|
||||
|
|
@ -224,8 +225,8 @@ common::Status NnapiExecutionProvider::Compile(const std::vector<onnxruntime::No
|
|||
{
|
||||
onnxruntime::GraphViewer graph_viewer(graph_body);
|
||||
nnapi::ModelBuilder builder(graph_viewer);
|
||||
builder.SetUseNCHW(false);
|
||||
builder.SetUseFp16(false);
|
||||
builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW);
|
||||
builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16);
|
||||
std::unique_ptr<nnapi::Model> nnapi_model;
|
||||
ORT_RETURN_IF_ERROR(builder.Compile(nnapi_model));
|
||||
|
||||
|
|
|
|||
|
|
@ -5,11 +5,13 @@
|
|||
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/providers/nnapi/nnapi_builtin/model.h"
|
||||
#include "core/providers/nnapi/nnapi_provider_factory.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class NnapiExecutionProvider : public IExecutionProvider {
|
||||
public:
|
||||
NnapiExecutionProvider();
|
||||
NnapiExecutionProvider(unsigned long nnapi_flags);
|
||||
virtual ~NnapiExecutionProvider();
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
|
|
@ -17,8 +19,13 @@ class NnapiExecutionProvider : public IExecutionProvider {
|
|||
const std::vector<const KernelRegistry*>& /*kernel_registries*/) const override;
|
||||
common::Status Compile(const std::vector<onnxruntime::Node*>& fused_nodes,
|
||||
std::vector<NodeComputeInfo>& node_compute_funcs) override;
|
||||
unsigned long GetNNAPIFlags() const { return nnapi_flags_; }
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::unique_ptr<onnxruntime::nnapi::Model>> nnapi_models_;
|
||||
|
||||
// The bit flags which define bool options for NNAPI EP, bits are defined as
|
||||
// NNAPIFlags in include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h
|
||||
const unsigned long nnapi_flags_;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -7,24 +7,25 @@
|
|||
using namespace onnxruntime;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
struct NnapiProviderFactory : IExecutionProviderFactory {
|
||||
NnapiProviderFactory() {}
|
||||
NnapiProviderFactory(unsigned long nnapi_flags)
|
||||
: nnapi_flags_(nnapi_flags) {}
|
||||
~NnapiProviderFactory() override {}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> CreateProvider() override;
|
||||
unsigned long nnapi_flags_;
|
||||
};
|
||||
|
||||
std::unique_ptr<IExecutionProvider> NnapiProviderFactory::CreateProvider() {
|
||||
return onnxruntime::make_unique<NnapiExecutionProvider>();
|
||||
return onnxruntime::make_unique<NnapiExecutionProvider>(nnapi_flags_);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi() {
|
||||
return std::make_shared<onnxruntime::NnapiProviderFactory>();
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(unsigned long nnapi_flags) {
|
||||
return std::make_shared<onnxruntime::NnapiProviderFactory>(nnapi_flags);
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options) {
|
||||
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Nnapi());
|
||||
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options, unsigned long nnapi_flags) {
|
||||
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Nnapi(nnapi_flags));
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ IExecutionProvider* TestOpenVINOExecutionProvider() {
|
|||
|
||||
#ifdef USE_NNAPI
|
||||
IExecutionProvider* TestNnapiExecutionProvider() {
|
||||
static NnapiExecutionProvider nnapi_provider;
|
||||
static NnapiExecutionProvider nnapi_provider(0);
|
||||
return &nnapi_provider;
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -364,7 +364,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
}
|
||||
if (enable_nnapi) {
|
||||
#ifdef USE_NNAPI
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(sf));
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(sf, 0));
|
||||
#else
|
||||
fprintf(stderr, "NNAPI is not supported in this build");
|
||||
return -1;
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
#endif
|
||||
} else if (provider_name == onnxruntime::kNnapiExecutionProvider) {
|
||||
#ifdef USE_NNAPI
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options));
|
||||
Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, 0));
|
||||
#else
|
||||
ORT_THROW("NNAPI is not supported in this build\n");
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ void RunAndVerifyOutputs(const std::string& model_file_name,
|
|||
run_options.run_tag = so.session_logid;
|
||||
|
||||
InferenceSessionWrapper session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>()));
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(onnxruntime::make_unique<::onnxruntime::NnapiExecutionProvider>(0)));
|
||||
ASSERT_STATUS_OK(session_object.Load(model_file_name));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
|
||||
|
|
@ -141,5 +141,15 @@ TEST(NnapiExecutionProviderTest, FunctionTest) {
|
|||
|
||||
RunAndVerifyOutputs(model_file_name, "NnapiExecutionProviderTest.FunctionTest", feeds, output_names, expected_dims_mul_m, expected_values_mul_m);
|
||||
}
|
||||
|
||||
TEST(NnapiExecutionProviderTest, NNAPIFlagsTest) {
|
||||
unsigned long nnapi_flags = NNAPI_FLAG_USE_NONE;
|
||||
nnapi_flags |= NNAPI_FLAG_USE_FP16;
|
||||
onnxruntime::NnapiExecutionProvider nnapi_ep(nnapi_flags);
|
||||
const auto flags = nnapi_ep.GetNNAPIFlags();
|
||||
ASSERT_TRUE(flags & NNAPI_FLAG_USE_FP16);
|
||||
ASSERT_FALSE(flags & NNAPI_FLAG_USE_NCHW);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Dnnl(i
|
|||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_NGraph(const char* ng_backend_type);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_OpenVINO(const char* device_type, bool enable_vpu_fast_compile, const char* device_id, size_t num_of_threads);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nuphar(bool, const char*);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi();
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(unsigned long);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Rknpu();
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Tensorrt(int device_id);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_MIGraphX(int device_id);
|
||||
|
|
@ -96,7 +96,7 @@ std::unique_ptr<IExecutionProvider> DefaultNupharExecutionProvider(bool allow_un
|
|||
|
||||
std::unique_ptr<IExecutionProvider> DefaultNnapiExecutionProvider() {
|
||||
#ifdef USE_NNAPI
|
||||
return CreateExecutionProviderFactory_Nnapi()->CreateProvider();
|
||||
return CreateExecutionProviderFactory_Nnapi(0)->CreateProvider();
|
||||
#else
|
||||
return nullptr;
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue