mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
constant fold DQ on INT32/UINT16 initializer
This commit is contained in:
parent
704523c2d8
commit
d2c822cb3a
2 changed files with 19 additions and 1 deletions
|
|
@ -219,7 +219,24 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
}
|
||||
|
||||
if (!can_constant_fold_qdq_node_unit) {
|
||||
bool can_constant_fold_dq_node = false;
|
||||
|
||||
// Another scenario in which a DequantizeLinear on an initializer (e.g. initializer -> DQ -> bias of X, where X is Gemm or Conv) can be constant folded is when:
|
||||
// - the DQ node does not produce a graph output
|
||||
// - The data type of initializer is not FP16, INT8, UINT8 and INT4
|
||||
// - Does X need to be either Gemm or Conv ?
|
||||
if (!can_constant_fold_qdq_node_unit && dequantize_linear_on_initializer_) {
|
||||
if (!graph.NodeProducesGraphOutput(*node)) { // DQ does not produce graph output
|
||||
const auto* input_def = node->InputDefs()[0]; // Get NodeArg of the initializer of the DequantizeLinear node;
|
||||
auto data_type = input_def->TypeAsProto()->tensor_type().elem_type();
|
||||
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
|
||||
data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT16) {
|
||||
can_constant_fold_dq_node = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!can_constant_fold_qdq_node_unit && !can_constant_fold_dq_node) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ class ConstantFolding : public GraphTransformer {
|
|||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
bool skip_dequantize_linear_;
|
||||
bool dequantize_linear_on_initializer_ = true;
|
||||
const ConfigOptions& config_options_;
|
||||
const InlinedHashSet<std::string> excluded_initializers_;
|
||||
const IExecutionProvider& execution_provider_;
|
||||
|
|
|
|||
Loading…
Reference in a new issue