diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index e755b4bfa6..4540d04543 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -219,7 +219,24 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, } } - if (!can_constant_fold_qdq_node_unit) { + bool can_constant_fold_dq_node = false; + + // Another scenario where a DequantizeLinear on an initializer (ex: initializer -> DQ -> bias of X (Gemm or Conv)) can be constant folded is if: + // - the DQ node does not produce a graph output + // - the data type of the initializer is INT32 or UINT16 (i.e. not FP16, INT8, UINT8 or INT4) + // - TODO: does X need to be either Gemm or Conv? + if (!can_constant_fold_qdq_node_unit && dequantize_linear_on_initializer_) { + if (!graph.NodeProducesGraphOutput(*node)) { // DQ does not produce graph output + const auto* input_def = node->InputDefs()[0]; // NodeArg of the DequantizeLinear node's first input (expected to be the initializer). NOTE(review): TypeAsProto() is dereferenced below without a null check - confirm it cannot be null here + auto data_type = input_def->TypeAsProto()->tensor_type().elem_type(); + if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || + data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) { + can_constant_fold_dq_node = true; + } + } + } + + if (!can_constant_fold_qdq_node_unit && !can_constant_fold_dq_node) { continue; } } diff --git a/onnxruntime/core/optimizer/constant_folding.h b/onnxruntime/core/optimizer/constant_folding.h index 14eb2a9c5f..72f0036a0a 100644 --- a/onnxruntime/core/optimizer/constant_folding.h +++ b/onnxruntime/core/optimizer/constant_folding.h @@ -32,6 +32,7 @@ class ConstantFolding : public GraphTransformer { Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; bool skip_dequantize_linear_; + bool dequantize_linear_on_initializer_ = true; const ConfigOptions& config_options_; const InlinedHashSet excluded_initializers_; const IExecutionProvider& execution_provider_;