constant fold DQ on INT32/UINT16 initializer

This commit is contained in:
Chi Lo 2025-01-14 11:58:39 -08:00
parent 704523c2d8
commit d2c822cb3a
2 changed files with 19 additions and 1 deletions

View file

@ -219,7 +219,24 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
}
}
if (!can_constant_fold_qdq_node_unit) {
bool can_constant_fold_dq_node = false;
// Another scenario where dequantizing on initializer (ex: initializer -> DQ -> bias of X (Gemm or Conv)) can be constant folded is if:
// - the DQ node does not produce a graph output
// - The data type of initializer is not FP16, INT8, UNIT8 and INT4
// - Does X need to be either Gemm or Conv ?
if (!can_constant_fold_qdq_node_unit && dequantize_linear_on_initializer_) {
if (!graph.NodeProducesGraphOutput(*node)) { // DQ does not produce graph output
const auto* input_def = node->InputDefs()[0]; // Get NodeArg of the initializer of the DequantizeLinear node;
auto data_type = input_def->TypeAsProto()->tensor_type().elem_type();
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
can_constant_fold_dq_node = true;
}
}
}
if (!can_constant_fold_qdq_node_unit && !can_constant_fold_dq_node) {
continue;
}
}

View file

@ -32,6 +32,7 @@ class ConstantFolding : public GraphTransformer {
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
bool skip_dequantize_linear_;
bool dequantize_linear_on_initializer_ = true;
const ConfigOptions& config_options_;
const InlinedHashSet<std::string> excluded_initializers_;
const IExecutionProvider& execution_provider_;