mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add NNAPI to providers that can be used via the python bindings. (#5867)
Update the ORT model conversion script: add args for specifying the optimization level and whether to use NNAPI, and add logic to create a list of required ops and an ORT format model that can be used with NNAPI.
This commit is contained in:
parent
3970eb2e5d
commit
f0142da59c
3 changed files with 86 additions and 10 deletions
|
|
@ -392,6 +392,15 @@ if (onnxruntime_USE_DML)
|
|||
)
|
||||
endif()
|
||||
|
||||
# When onnxruntime_USE_NNAPI_BUILTIN is set, add a post-build step on the
# onnxruntime_pybind11_state target that copies the onnxruntime_providers_nnapi
# target file into the onnxruntime/capi/ directory next to ${test_data_target}.
if (onnxruntime_USE_NNAPI_BUILTIN)
|
||||
add_custom_command(
|
||||
TARGET onnxruntime_pybind11_state POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
$<TARGET_FILE:onnxruntime_providers_nnapi>
|
||||
$<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/capi/
|
||||
)
|
||||
endif()
|
||||
|
||||
# Include the language interop ops CMake script only when that option is enabled.
if (onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS)
|
||||
include(onnxruntime_language_interop_ops.cmake)
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -207,6 +207,7 @@ std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_VITISA
|
|||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ACL(int use_arena);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_ArmNN(int use_arena);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_DML(int device_id);
|
||||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(uint32_t flags);
|
||||
} // namespace onnxruntime
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
|
@ -437,8 +438,10 @@ static const std::vector<std::string>& GetAllProviders() {
|
|||
static std::vector<std::string> all_providers = {kTensorrtExecutionProvider, kCudaExecutionProvider,
|
||||
kMIGraphXExecutionProvider, kRocmExecutionProvider,
|
||||
kOpenVINOExecutionProvider, kDnnlExecutionProvider,
|
||||
kNupharExecutionProvider, kVitisAIExecutionProvider, kArmNNExecutionProvider,
|
||||
kAclExecutionProvider, kDmlExecutionProvider, kCpuExecutionProvider};
|
||||
kNupharExecutionProvider, kVitisAIExecutionProvider,
|
||||
kNnapiExecutionProvider,
|
||||
kArmNNExecutionProvider, kAclExecutionProvider,
|
||||
kDmlExecutionProvider, kCpuExecutionProvider};
|
||||
return all_providers;
|
||||
}
|
||||
|
||||
|
|
@ -716,6 +719,13 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
|
|||
} else if (type == kDmlExecutionProvider) {
|
||||
#ifdef USE_DML
|
||||
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_DML(0));
|
||||
#endif
|
||||
} else if (type == kNnapiExecutionProvider) {
|
||||
#if defined(USE_NNAPI)
|
||||
#if !defined(__ANDROID__)
|
||||
LOGS_DEFAULT(WARNING) << "NNAPI execution provider can only be used to generate ORT format model in this build.";
|
||||
#endif
|
||||
RegisterExecutionProvider(sess, *onnxruntime::CreateExecutionProviderFactory_Nnapi(0));
|
||||
#endif
|
||||
} else {
|
||||
// unknown provider
|
||||
|
|
@ -902,7 +912,10 @@ void addGlobalMethods(py::module& m, const Environment& env) {
|
|||
onnxruntime::CreateExecutionProviderFactory_ArmNN(0),
|
||||
#endif
|
||||
#ifdef USE_DML
|
||||
onnxruntime::CreateExecutionProviderFactory_DML(0)
|
||||
onnxruntime::CreateExecutionProviderFactory_DML(0),
|
||||
#endif
|
||||
#ifdef USE_NNAPI
|
||||
onnxruntime::CreateExecutionProviderFactory_NNAPI(0),
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def create_config_file(optimized_model_path, config_file_path):
|
|||
exclude_unused_ops(optimized_model_path, config_path=None, ort_root=None, output_config_path=config_file_path)
|
||||
|
||||
|
||||
def convert(model_path: str):
|
||||
def convert(model_path: str, optimization_level: ort.GraphOptimizationLevel, use_nnapi: bool):
|
||||
models = glob.glob(os.path.join(model_path, '**', '*.onnx'), recursive=True)
|
||||
|
||||
if len(models) == 0:
|
||||
|
|
@ -42,19 +42,41 @@ def convert(model_path: str):
|
|||
|
||||
so = ort.SessionOptions()
|
||||
so.optimized_model_filepath = onnx_target_path
|
||||
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED # Skip NCHWc optimizations
|
||||
so.graph_optimization_level = optimization_level
|
||||
|
||||
print("Optimizing ONNX model {}".format(model))
|
||||
# creating the session will result in the optimized model being saved
|
||||
_ = ort.InferenceSession(model, sess_options=so)
|
||||
# creating the session will result in the optimized model being saved. we use just the CPU EP for this step
|
||||
providers = ['CPUExecutionProvider']
|
||||
_ = ort.InferenceSession(model, sess_options=so, providers=providers)
|
||||
|
||||
# special case if we're enabling a compiling EP like NNAPI. we don't currently have a way to read the
|
||||
# required ops from an ORT format model, so we need an ONNX model that is only optimized to 'basic' level
|
||||
# to ensure all the nodes that NNAPI may take still exist. we can merge the required operators from that
|
||||
# with the required operators from an ONNX model optimized to a higher level (if the user requested that).
|
||||
# we must use this model with creating the ORT format model to maximize the nodes that NNAPI can potentially
|
||||
# take, so replace onnx_target_path with the new path.
|
||||
if use_nnapi and \
|
||||
(optimization_level == ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED or
|
||||
optimization_level == ort.GraphOptimizationLevel.ORT_ENABLE_ALL):
|
||||
onnx_target_path = os.path.join(tmpdirname, re.sub('.onnx$', '.optimized.basic.onnx', model_filename))
|
||||
so.optimized_model_filepath = onnx_target_path
|
||||
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
_ = ort.InferenceSession(model, sess_options=so, providers=providers)
|
||||
|
||||
# Second, convert optimized ONNX model to ORT format
|
||||
# we enable the compiling EPs when we generate the ORT format model so that we preserve the nodes it may
|
||||
# take, but allow optimization on any others
|
||||
if use_nnapi:
|
||||
# providers are priority based, so register NNAPI first
|
||||
providers.insert(0, 'NnapiExecutionProvider')
|
||||
|
||||
so.optimized_model_filepath = ort_target_path
|
||||
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL # Convert model as-is so we don't change the kernels in this step # noqa
|
||||
# Use original optimization level so that if NNAPI is enabled we optimize nodes it is not taking
|
||||
so.graph_optimization_level = optimization_level
|
||||
so.add_session_config_entry('session.save_model_format', 'ORT')
|
||||
|
||||
print("Converting optimized ONNX model to ORT format model {}".format(ort_target_path))
|
||||
_ = ort.InferenceSession(onnx_target_path, sess_options=so)
|
||||
_ = ort.InferenceSession(onnx_target_path, sess_options=so, providers=providers)
|
||||
|
||||
# orig_size = os.path.getsize(onnx_target_path)
|
||||
# new_size = os.path.getsize(ort_target_path)
|
||||
|
|
@ -65,6 +87,22 @@ def convert(model_path: str):
|
|||
create_config_file(tmpdirname, os.path.join(model_path, 'required_operators.config'))
|
||||
|
||||
|
||||
def _get_optimization_level(level):
|
||||
if level == 'disable':
|
||||
return ort.GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
if level == 'basic':
|
||||
# Constant folding and other optimizations that only use ONNX operators
|
||||
return ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
if level == 'extended':
|
||||
# Optimizations using custom operators, excluding NCHWc optimizations
|
||||
return ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
|
||||
if level == 'all':
|
||||
# all optimizations, including NCHWc (which has hardware specific logic)
|
||||
return ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
|
||||
raise ValueError('Invalid optimization level of ' + level)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
os.path.basename(__file__),
|
||||
|
|
@ -75,14 +113,30 @@ def parse_args():
|
|||
This configuration file should be used as input to the minimal build'''
|
||||
)
|
||||
|
||||
parser.add_argument('--use_nnapi', action='store_true',
|
||||
help='Enable the NNAPI Execution Provider when creating models and determining required '
|
||||
'operators. Note that this will limit the optimizations possible on nodes that the '
|
||||
'NNAPI execution provider takes, in order to preserve those nodes in the ORT format '
|
||||
'model.')
|
||||
|
||||
parser.add_argument('--optimization_level', default='extended',
|
||||
choices=['disable', 'basic', 'extended', 'all'],
|
||||
help="Level to optimize ONNX model with, prior to converting to ORT format model. "
|
||||
"These map to the onnxruntime.GraphOptimizationLevel values. "
|
||||
"NOTE: It is NOT recommended to use 'all' unless you are creating the ORT format model on "
|
||||
"the device you will run it on, as the generated model may not be valid on other hardware."
|
||||
)
|
||||
|
||||
parser.add_argument('model_path', help='Provide path to directory containing ONNX model/s to convert. '
|
||||
'Files with .onnx extension will be processed.')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
convert(args.model_path)
|
||||
optimization_level = _get_optimization_level(args.optimization_level)
|
||||
convert(args.model_path, optimization_level, args.use_nnapi)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
Loading…
Reference in a new issue