[QNN EP] Add session option to disable fallback to default CPU EP (#16016)

### Description
Adds the session config option `session.disable_cpu_ep_fallback`, which lets
the user prevent the CPU EP from handling
nodes that are not supported by the other execution providers.

```C++
// Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are
// assigned (i.e., "fallback") to the CPU EP by default.
//
// This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP.
// If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot
// fully support all of the nodes in the graph.
//
// It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation
// will also fail with an error.
//
// Option values:
// - "0": CPU EP fallback is not disabled. [DEFAULT]
// - "1": CPU EP fallback is disabled.
static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback";
```

#### Example use
```C++
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/onnxruntime_session_options_config_keys.h"

int main(int argc, char** argv) {
    Ort::SessionOptions so;
    so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");  // Disable fallback to the CPU EP.

    onnxruntime::ProviderOptions options;
#if defined(_WIN32)
    options["backend_path"] = "QnnCpu.dll";
#else
    options["backend_path"] = "libQnnCpu.so";
#endif

    so.AppendExecutionProvider("QNN", options);

    const ORTCHAR_T* ort_model_path = ORT_MODEL_FOLDER "qnn_ep_partial_support.onnx";
    Ort::Session session(*ort_env, ort_model_path, so);  // Throws exception if nodes fallback to CPU
    // ...
```

### Motivation and Context
Makes it easier for application developers to ensure that the entire
model runs on specific EPs. This is critical for Qualcomm scenarios: if
the compute cannot be offloaded to the NPU, running on the CPU is not
acceptable (it could be the difference between a 90-second inference and
a 6-second inference).

---------

Co-authored-by: Pranav Sharma <prs@microsoft.com>
This commit is contained in:
Adrian Lizarraga 2023-05-23 17:56:32 -07:00 committed by GitHub
parent b9d39e3405
commit efc84a43e8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 214 additions and 16 deletions

View file

@ -197,3 +197,18 @@ static const char* const kOrtSessionOptionsConfigIntraOpThreadAffinities = "sess
// 3) after the L1 transformers are applied to the updated graph.
// The model will be saved to filename post_layout_transform_step_<step_number>.onnx.
static const char* const kDebugLayoutTransformation = "session.debug_layout_transformation";
// Graph nodes that are not supported by the execution providers (EPs) explicitly added to the session are
// assigned to (i.e., "fall back" to) the CPU EP by default.
//
// This option allows the user to disable the fallback of unsupported graph nodes to the CPU EP.
// If this option is set to "1", session creation will fail if the execution providers other than the CPU EP cannot
// fully support all of the nodes in the graph.
//
// It is invalid to set this option and explicitly add the CPU EP to the session. In this case, session creation
// will also fail with an error.
//
// Option values:
// - "0": CPU EP fallback is not disabled. [DEFAULT]
// - "1": CPU EP fallback is disabled.
//
// NOTE(review): the consumers of this key that are visible alongside it compare the value against "1" exactly,
// so any other value appears to leave CPU EP fallback enabled -- confirm before relying on it.
static const char* const kOrtSessionOptionsDisableCPUEPFallback = "session.disable_cpu_ep_fallback";

View file

@ -8,6 +8,7 @@
#include "core/framework/allocatormgr.h"
#include "core/framework/compute_capability.h"
#include "core/graph/graph_viewer.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/framework/kernel_registry.h"
#include "core/providers/partitioning_utils.h"
@ -72,9 +73,15 @@ void QNNExecutionProvider::ParseHtpPerformanceMode(std::string htp_performance_m
}
}
QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map)
QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map,
const SessionOptions* session_options)
: IExecutionProvider{onnxruntime::kQnnExecutionProvider, true},
runtime_options_(provider_options_map) {
if (session_options) {
disable_cpu_ep_fallback_ = session_options->config_options.GetConfigOrDefault(
kOrtSessionOptionsDisableCPUEPFallback, "0") == "1";
}
static const std::string CONTEXT_CACHE_ENABLED = "qnn_context_cache_enable";
auto context_cache_enabled_pos = runtime_options_.find(CONTEXT_CACHE_ENABLED);
if (context_cache_enabled_pos != runtime_options_.end()) {
@ -310,14 +317,37 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(),
load_from_cached_context, logger);
// Helper that builds a comma-separated list of the node units that are NOT in `supported_nodes`.
// A node unit is treated as unsupported when the node returned by GetNode() is absent from the set.
// Returns an empty string when every node unit is supported.
// Ex: { name: mul_123, type: Mul }, { name: add_4, type: Add }
auto get_unsupported_node_names = [&node_unit_holder, &supported_nodes]() -> std::string {
  std::stringstream ss;
  const size_t num_node_units = node_unit_holder.size();
  // Tracks whether an entry has already been written. The separator is emitted *before* every
  // entry except the first, so the result never has missing separators or a trailing ", ",
  // regardless of which node units turn out to be unsupported.
  bool wrote_entry = false;

  for (size_t i = 0; i < num_node_units; ++i) {
    const auto& node_unit = node_unit_holder[i];
    if (supported_nodes.find(&node_unit->GetNode()) != supported_nodes.end()) {
      continue;  // Supported by QNN EP; not part of the report.
    }

    if (wrote_entry) {
      ss << ", ";
    }
    ss << "{ name: " << node_unit->Name() << ", type: " << node_unit->OpType() << " }";
    wrote_entry = true;
  }

  return ss.str();
};
if (supported_nodes.empty()) {
LOGS(logger, INFO) << "Number of partitions supported by QNN EP: 0";
return result;
} else if (supported_nodes.size() == 1) {
const auto* node = *supported_nodes.begin();
if (node->OpType() == "QuantizeLinear" || node->OpType() == "DequantizeLinear") {
LOGS(logger, INFO) << "It doesn't make sense just run a Q/DQ node on HTP.";
LOGS(logger, INFO) << "Number of partitions supported by QNN EP: 0";
LOGS(logger, WARNING) << "It doesn't make sense just run a Q/DQ node on HTP.";
LOGS(logger, WARNING) << "Number of partitions supported by QNN EP: 0";
if (disable_cpu_ep_fallback_) {
LOGS(logger, ERROR) << "Unsupported nodes in QNN EP: " << get_unsupported_node_names();
}
return result;
}
}
@ -338,6 +368,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
[](const auto& partition) -> size_t {
return partition && partition->sub_graph ? partition->sub_graph->nodes.size() : 0;
});
const size_t num_nodes_in_graph = static_cast<size_t>(graph_viewer.NumberOfNodes());
if (load_from_cached_context && 1 == num_of_partitions) {
rt = qnn_backend_manager_->ValidateWithContextFile(GetFileNameFromModelPath(graph_viewer.ModelPath()),
@ -349,14 +380,20 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
}
if (num_of_partitions > 1) {
ORT_ENFORCE(!context_cache_enabled_, "Only support singel partition for context cache feature.");
ORT_ENFORCE(!context_cache_enabled_, "Only support single partition for context cache feature.");
}
const auto summary_msg = MakeString("Number of partitions supported by QNN EP: ", num_of_partitions,
", number of nodes in the graph: ", graph_viewer.NumberOfNodes(),
", number of nodes in the graph: ", num_nodes_in_graph,
", number of nodes supported by QNN: ", num_of_supported_nodes);
LOGS(logger, INFO) << summary_msg;
// Print list of unsupported nodes to the ERROR logger if the CPU EP
// has been disabled for this inference session.
if (disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) {
LOGS(logger, ERROR) << "Unsupported nodes in QNN EP: " << get_unsupported_node_names();
}
return result;
}

View file

@ -4,6 +4,7 @@
#pragma once
#include "core/framework/execution_provider.h"
#include "core/framework/session_options.h"
#include <string>
#include "core/providers/qnn/builder/qnn_backend_manager.h"
#include "core/providers/qnn/builder/qnn_model.h"
@ -13,7 +14,7 @@ namespace onnxruntime {
// Logical device representation.
class QNNExecutionProvider : public IExecutionProvider {
public:
explicit QNNExecutionProvider(const ProviderOptions& provider_options_map);
explicit QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options);
virtual ~QNNExecutionProvider() = default;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QNNExecutionProvider);
@ -69,6 +70,7 @@ class QNNExecutionProvider : public IExecutionProvider {
uint32_t rpc_control_latency_ = 0;
bool context_cache_enabled_ = false;
std::string context_cache_path_ = "";
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
};
} // namespace onnxruntime

View file

@ -9,22 +9,25 @@
namespace onnxruntime {
struct QNNProviderFactory : IExecutionProviderFactory {
QNNProviderFactory(const ProviderOptions& provider_options_map) : provider_options_map_(provider_options_map) {
QNNProviderFactory(const ProviderOptions& provider_options_map, const SessionOptions* session_options)
: provider_options_map_(provider_options_map), session_options_(session_options) {
}
~QNNProviderFactory() override {
}
std::unique_ptr<IExecutionProvider> CreateProvider() override {
return std::make_unique<QNNExecutionProvider>(provider_options_map_);
return std::make_unique<QNNExecutionProvider>(provider_options_map_, session_options_);
}
private:
ProviderOptions provider_options_map_;
const SessionOptions* session_options_;
};
std::shared_ptr<IExecutionProviderFactory> QNNProviderFactoryCreator::Create(const ProviderOptions& provider_options_map) {
return std::make_shared<onnxruntime::QNNProviderFactory>(provider_options_map);
std::shared_ptr<IExecutionProviderFactory> QNNProviderFactoryCreator::Create(const ProviderOptions& provider_options_map,
const SessionOptions* session_options) {
return std::make_shared<onnxruntime::QNNProviderFactory>(provider_options_map, session_options);
}
} // namespace onnxruntime

View file

@ -9,7 +9,10 @@
#include "core/providers/providers.h"
namespace onnxruntime {
struct SessionOptions;
struct QNNProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& provider_options_map);
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& provider_options_map,
const SessionOptions* session_options);
};
} // namespace onnxruntime

View file

@ -1557,6 +1557,42 @@ common::Status InferenceSession::Initialize() {
}
}
const bool disable_cpu_ep_fallback = session_options_.config_options.GetConfigOrDefault(
kOrtSessionOptionsDisableCPUEPFallback, "0") == "1";
// Handle the option to disable the fallback of graph nodes to the CPU EP.
// If the user disabled fallback, but also explicitly added the CPU EP to the session, return an error status.
// If the user disabled fallback and any graph node is assigned to the CPU EP, return an error status.
if (disable_cpu_ep_fallback) {
// Scans the nodes of `graph` and reports whether at least one of them has an execution provider
// type that is either empty or equal to the CPU EP's type.
auto are_nodes_assigned_to_cpu_ep = [](const Graph& graph) -> bool {
  bool found_cpu_node = false;

  for (const auto& graph_node : graph.Nodes()) {
    const auto& ep_type = graph_node.GetExecutionProviderType();
    const bool assigned_to_other_ep = !ep_type.empty() && !(ep_type == onnxruntime::kCpuExecutionProvider);
    if (!assigned_to_other_ep) {
      found_cpu_node = true;
      break;  // One match is enough; stop scanning.
    }
  }

  return found_cpu_node;
};
if (!execution_providers_.GetCpuProviderWasImplicitlyAdded()) {
const char* err_msg =
"Conflicting session configuration: explicitly added the CPU EP to the "
"session, but also disabled fallback to the CPU EP via session configuration options.";
LOGS(*session_logger_, ERROR) << err_msg;
ORT_RETURN_IF_ERROR_SESSIONID_(ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, err_msg));
} else if (are_nodes_assigned_to_cpu_ep(graph)) {
const char* err_msg =
"This session contains graph nodes that are assigned to the default CPU EP, "
"but fallback to CPU EP has been explicitly disabled by the user.";
LOGS(*session_logger_, ERROR) << err_msg;
ORT_RETURN_IF_ERROR_SESSIONID_(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, err_msg));
}
}
// Update temporary copies of metadata, input- and output definitions to the same state as the resolved graph
ORT_RETURN_IF_ERROR_SESSIONID_(SaveModelMetadata(*model_));
#else // !defined(ORT_MINIMAL_BUILD)

View file

@ -68,7 +68,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
if (strcmp(provider_name, "QNN") == 0) {
#if defined(USE_QNN)
options->provider_factories.push_back(QNNProviderFactoryCreator::Create(provider_options));
options->provider_factories.push_back(QNNProviderFactoryCreator::Create(provider_options, &(options->value)));
#else
status = create_not_supported_status();
#endif

View file

@ -841,7 +841,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
#ifdef USE_QNN
auto cit = provider_options_map.find(type);
return onnxruntime::QNNProviderFactoryCreator::Create(
cit == provider_options_map.end() ? ProviderOptions{} : cit->second)
cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options)
->CreateProvider();
#endif
} else {

View file

@ -5,6 +5,7 @@
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/providers/cpu/cpu_provider_factory.h" // For OrtSessionOptionsAppendExecutionProvider_CPU
#include "core/session/inference_session.h"
#include "test/providers/qnn/qnn_test_utils.h"
@ -31,8 +32,15 @@ namespace test {
// Loads a simple ONNX model that adds floats.
TEST(QnnEP, TestAddEpUsingPublicApi) {
{
// C++ API test
Ort::SessionOptions so;
// Can only enforce that model runs on QNN in linux CI machines
// because they support the CPU backend and emulate the HTP backend.
// TODO: Remove #ifdef when Windows Arm64 machines support the CPU backend.
#if defined(__linux__)
so.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Disable fallback to the CPU EP.
#endif
onnxruntime::ProviderOptions options;
#if defined(_WIN32)
@ -63,6 +71,100 @@ TEST(QnnEP, TestAddEpUsingPublicApi) {
}
}
// Exercises the `session.disable_cpu_ep_fallback` configuration option with a backend that cannot be loaded.
// With the option enabled, creating the session is expected to throw because the backend cannot be found;
// the test checks for the ORT_FAIL error code and the "assigned to the default CPU EP" message.
TEST(QnnEP, TestDisableCPUFallback_BackendNotFound) {
#if defined(_WIN32)
  const char* const missing_backend_path = "DoesNotExist.dll";  // Invalid backend path!
#else
  const char* const missing_backend_path = "libDoesNotExist.so";  // Invalid backend path!
#endif

  Ort::SessionOptions session_opts;
  session_opts.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");  // Disable fallback to the CPU EP.

  onnxruntime::ProviderOptions qnn_options;
  qnn_options["backend_path"] = missing_backend_path;
  session_opts.AppendExecutionProvider("QNN", qnn_options);

  const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "constant_floats.onnx";

  try {
    Ort::Session session(*ort_env, model_path, session_opts);
    FAIL();  // Session creation must not succeed.
  } catch (const Ort::Exception& ex) {
    ASSERT_EQ(ex.GetOrtErrorCode(), ORT_FAIL);
    ASSERT_THAT(ex.what(), testing::HasSubstr("This session contains graph nodes that are assigned to the default "
                                              "CPU EP, but fallback to CPU EP has been explicitly disabled by "
                                              "the user."));
  }
}
// Exercises the `session.disable_cpu_ep_fallback` configuration option with a model that cannot be
// assigned to QNN EP in its entirety. With the option enabled, session creation should throw; the test
// checks for the ORT_FAIL error code and the "assigned to the default CPU EP" message.
TEST(QnnEP, TestDisableCPUFallback_ModelNotFullySupported) {
#if defined(_WIN32)
  const char* const qnn_backend_path = "QnnCpu.dll";
#else
  const char* const qnn_backend_path = "libQnnCpu.so";
#endif

  Ort::SessionOptions session_opts;
  session_opts.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");  // Disable fallback to the CPU EP.

  onnxruntime::ProviderOptions qnn_options;
  qnn_options["backend_path"] = qnn_backend_path;
  session_opts.AppendExecutionProvider("QNN", qnn_options);

  // QNN EP doesn't support MatMulInteger.
  const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "qnn_ep_partial_support.onnx";

  try {
    Ort::Session session(*ort_env, model_path, session_opts);
    FAIL();  // Session creation must not succeed.
  } catch (const Ort::Exception& ex) {
    ASSERT_EQ(ex.GetOrtErrorCode(), ORT_FAIL);
    ASSERT_THAT(ex.what(), testing::HasSubstr("This session contains graph nodes that are assigned to the default "
                                              "CPU EP, but fallback to CPU EP has been explicitly disabled by "
                                              "the user."));
  }
}
// Exercises an invalid combination: the `session.disable_cpu_ep_fallback` configuration option is set
// AND the CPU EP is explicitly appended to the session. Session creation is expected to throw with
// ORT_INVALID_ARGUMENT and the "Conflicting session configuration" message.
TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) {
#if defined(_WIN32)
  const char* const qnn_backend_path = "QnnCpu.dll";
#else
  const char* const qnn_backend_path = "libQnnCpu.so";
#endif

  Ort::SessionOptions session_opts;
  session_opts.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");  // Disable fallback to the CPU EP.

  onnxruntime::ProviderOptions qnn_options;
  qnn_options["backend_path"] = qnn_backend_path;
  session_opts.AppendExecutionProvider("QNN", qnn_options);

  // Invalid! Adds CPU EP to session, but also disables CPU fallback.
  Ort::Status cpu_ep_status(OrtSessionOptionsAppendExecutionProvider_CPU(session_opts, 1));

  const ORTCHAR_T* model_path = ORT_MODEL_FOLDER "constant_floats.onnx";

  try {
    Ort::Session session(*ort_env, model_path, session_opts);
    FAIL();  // Session creation must not succeed.
  } catch (const Ort::Exception& ex) {
    ASSERT_EQ(ex.GetOrtErrorCode(), ORT_INVALID_ARGUMENT);
    ASSERT_THAT(ex.what(), testing::HasSubstr("Conflicting session configuration: explicitly added the CPU EP to the "
                                              "session, but also disabled fallback to the CPU EP via session "
                                              "configuration options."));
  }
}
// Helper function that runs an ONNX model with a NHWC Resize operator to test that
// type/shape inference succeeds during layout transformation.
// Refer to onnxruntime/core/graph/contrib_ops/nhwc_inference_context.h.

Binary file not shown.

View file

@ -229,7 +229,7 @@ std::unique_ptr<IExecutionProvider> DefaultQnnExecutionProvider() {
backend_path = "./QnnCpu.dll";
#endif
provider_options_map["backend_path"] = backend_path;
return QNNProviderFactoryCreator::Create(provider_options_map)->CreateProvider();
return QNNProviderFactoryCreator::Create(provider_options_map, nullptr)->CreateProvider();
#else
return nullptr;
#endif
@ -237,7 +237,7 @@ std::unique_ptr<IExecutionProvider> DefaultQnnExecutionProvider() {
std::unique_ptr<IExecutionProvider> QnnExecutionProviderWithOptions(const ProviderOptions& options) {
#ifdef USE_QNN
return QNNProviderFactoryCreator::Create(options)->CreateProvider();
return QNNProviderFactoryCreator::Create(options, nullptr)->CreateProvider();
#else
ORT_UNUSED_PARAMETER(options);
return nullptr;