diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index d7039cb4b7..135e53b447 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -91,7 +91,8 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(kMSDomain); #if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr selector = std::make_unique(); + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"AveragePool", {}}, {"LeakyRelu", {}}, @@ -108,16 +109,31 @@ void UnaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { void BinaryOpQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { // 4 nodes. 2 x DQ for inputs, target, Q // Replace with internal QLinear version of operator. Delete all original nodes. - const std::string action_name{"2DQ"}; - std::unique_ptr action = std::make_unique(kMSDomain); #if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr selector = std::make_unique(); - qdq_selector_action_registry.RegisterSelectorAndAction(action_name, - {{"Add", {}}, - {"Mul", {}}}, - std::move(selector), - std::move(action)); + { + const std::string action_name{"2DQ_Mul"}; + std::unique_ptr action = std::make_unique(kMSDomain); + + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); + qdq_selector_action_registry.RegisterSelectorAndAction(action_name, + {{"Mul", {}}}, + std::move(selector), + std::move(action)); + } + + { + const std::string action_name{"2DQ_Add"}; + std::unique_ptr action = std::make_unique(kMSDomain); + + std::unique_ptr selector = std::make_unique(); + + qdq_selector_action_registry.RegisterSelectorAndAction(action_name, + {{"Add", {}}}, + std::move(selector), + std::move(action)); + } #else qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); @@ -195,7 +211,8 @@ void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr selector = std::make_unique(); + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Gemm", {}}}, std::move(selector), @@ -215,7 +232,9 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { std::unique_ptr action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) - std::unique_ptr selector = std::make_unique(); + + std::vector providers = {kCpuExecutionProvider}; + std::unique_ptr selector = std::make_unique(providers); qdq_selector_action_registry.RegisterSelectorAndAction(action_name, {{"Where", {}}}, std::move(selector), @@ -250,8 +269,8 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( "QDQSelectorActionTransformer", CreateSelectorActionRegistry(is_int8_allowed), apply_context, - // this transformer is only compatible with the CPU EP - {kCpuExecutionProvider}} { + // this transformer is only compatible with the CPU and DML EP + {kCpuExecutionProvider, kDmlExecutionProvider}} { } } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 565afcc67e..20521c5404 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -85,6 +85,14 @@ std::optional NodeGroupSelector::GetQDQSelection(const GraphViewer& g } std::optional BaseSelector::Select(const GraphViewer& graph_viewer, const Node& node) const { + + const std::string_view node_ep = node.GetExecutionProviderType(); + + if (!compatible_providers_.empty() && + std::find(compatible_providers_.begin(), compatible_providers_.end(), node_ep) == compatible_providers_.end()) { + return std::nullopt; + } + const auto qdq_group = node_group_selector_->GetQDQSelection(graph_viewer, node); if (!qdq_group.has_value()) { return std::nullopt; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index ab9ad45697..6168761175 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -179,12 +179,16 @@ class BaseSelector : public NodeSelector { // We std::move SelectorActionRegistry into the SelectorActionTransformer so this class needs to have a move ctor BaseSelector(BaseSelector&& rhs) noexcept - : node_group_selector_{std::move(rhs.node_group_selector_)} { + : node_group_selector_{std::move(rhs.node_group_selector_)}, + compatible_providers_{std::move(rhs.compatible_providers_)} { } protected: - BaseSelector(std::unique_ptr node_group_selector) - : node_group_selector_{std::move(node_group_selector)} {} + BaseSelector(std::unique_ptr node_group_selector, std::vector compatible_providers = {}) + : node_group_selector_{std::move(node_group_selector)} + { + compatible_providers_.assign(compatible_providers.begin(), compatible_providers.end()); + } // override if you need to adjust the values in NodesToOptimize. // e.g. add entries for missing optional DQ inputs or set num_inputs to handle variadic inputs @@ -193,6 +197,7 @@ class BaseSelector : public NodeSelector { private: std::unique_ptr node_group_selector_; + std::vector compatible_providers_; }; class DropQDQNodesSelector : public BaseSelector { @@ -207,12 +212,12 @@ class DropDQNodesSelector : public BaseSelector { class UnarySelector : public BaseSelector { public: - UnarySelector() : BaseSelector(std::make_unique()) {} + UnarySelector(std::vector compatible_providers) : BaseSelector(std::make_unique(), compatible_providers) {} }; class BinarySelector : public BaseSelector { public: - BinarySelector() : BaseSelector(std::make_unique()) {} + BinarySelector(std::vector compatible_providers = {}) : BaseSelector(std::make_unique(), compatible_providers) {} }; // Variadic DQ nodes -> node -> Q @@ -240,7 +245,7 @@ class ConvSelector : public BaseSelector { }; class WhereSelector : public BaseSelector { public: - WhereSelector() : BaseSelector(std::make_unique()) {} + WhereSelector(std::vector compatible_providers = {}) : BaseSelector(std::make_unique(), compatible_providers) {} }; // 2 DQ nodes for input -> node -> optional Q if QLinearMatMul, MatMulIntegerToFloat if not class MatMulSelector : public BaseSelector { @@ -253,8 +258,8 @@ class MatMulSelector : public BaseSelector { // Output: optional Q node for Y class GemmSelector : public BaseSelector { public: - GemmSelector() - : BaseSelector(std::make_unique()) {} + GemmSelector(std::vector compatible_providers = {}) + : BaseSelector(std::make_unique(), compatible_providers) {} void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override; }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f88bf8cf24..34aa109345 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -555,26 +555,6 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr session_options_.enable_mem_pattern = false; } - // Default this option to true when the DML EP is registered. - // This should be removed if QDQ is supported for DML through QDQSelectorActionTransformer and the DML EP does not - // rely on the constant folding pass for DequantizeLinear. - optional disable_quant_qdq = session_options_.config_options.GetConfigEntry(kOrtSessionOptionsDisableQuantQDQ); - - if (disable_quant_qdq == std::nullopt) { - LOGS(*session_logger_, INFO) - << "QDQ quantization is not supported while using the DML Execution Provider. " - << "So disabling it for this session since it uses the DML Execution Provider."; - - auto st = session_options_.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"); - if (!st.IsOK()) { - return st; - } - } else if (*disable_quant_qdq != "1") { - LOGS(*session_logger_, WARNING) - << "QDQ quantization is not supported while using the DML Execution Provider. " - << "It is enabled within session options which may result in lower performance."; - } - // Parallel execution mode does not support DML EP if (session_options_.execution_mode != ExecutionMode::ORT_SEQUENTIAL) { LOGS(*session_logger_, INFO)