diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index ba2b87b5aa..9684394da0 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -63,7 +63,7 @@ #ifdef MLAS_TARGET_AMD64_IX86 #include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h" #endif -#include "core/optimizer/qdq_transformer/bias_quantization.h" +#include "core/optimizer/qdq_transformer/weight_bias_quantization.h" #include "core/optimizer/qdq_transformer/clip_quantizelinear.h" #include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h" #include "core/optimizer/qdq_transformer/qdq_propagation.h" @@ -245,7 +245,7 @@ InlinedVector> GenerateTransformers( if (!disable_quant_qdq) { transformers.emplace_back(std::make_unique()); - transformers.emplace_back(std::make_unique()); + transformers.emplace_back(std::make_unique()); // EnsureUniqueDQForNodeUnit is actually a required graph transformation. The unique DQ per QDQ node unit input // condition that it ensures is important for the partitioning that happens after Level1 optimizers are run. diff --git a/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc deleted file mode 100644 index 9e9665e14e..0000000000 --- a/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/optimizer/qdq_transformer/bias_quantization.h" - -#include "core/common/common.h" -#include "core/graph/graph_utils.h" -#include "core/graph/graph_viewer.h" -#include "core/optimizer/utils.h" -#include "core/optimizer/qdq_transformer/qdq_util.h" - -namespace onnxruntime { - -Status BiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { - const GraphViewer graph_viewer{graph}; - const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); - for (const auto node_idx : node_indices) { - auto* node_ptr = graph.GetNode(node_idx); - if (!node_ptr) { - continue; - } - - Node& node = *node_ptr; - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - - const auto& input_defs = node.InputDefs(); - - // It's Conv/Gemm node with an initializer bias. - if ((node.OpType() != "Conv" && node.OpType() != "Gemm") || input_defs.size() < 3 || !input_defs[2]->Exists() || - !graph_utils::IsInitializer(graph, input_defs[2]->Name(), true)) { - continue; - } - - auto bias_shape = input_defs[2]->Shape(); - if (!bias_shape || bias_shape->dim_size() != 1) { - continue; - } - int64_t bias_size = bias_shape->dim(0).dim_value(); - - // input_0 and input_1 are outputs of DequantizeLinear nodes. - const Node* parent_node_0 = graph.GetProducerNode(input_defs[0]->Name()); - const Node* parent_node_1 = graph.GetProducerNode(input_defs[1]->Name()); - if (!parent_node_0 || !parent_node_1 || parent_node_0->OpType() != QDQ::DQOpName || - parent_node_1->OpType() != QDQ::DQOpName) { - continue; - } - - Node& dq_0 = *graph.GetNode(parent_node_0->Index()); - Node& dq_1 = *graph.GetNode(parent_node_1->Index()); - - // Currently we require input_0 is per-tensor scale. - if (!optimizer_utils::IsScalar(*dq_0.InputDefs()[1])) { - continue; - } - - // For input_1, it's either per-tensor scale or per-channel scale on specific axis (0 for Conv and 1 for Gemm). - bool is_per_tensor_scale = true; - if (!optimizer_utils::IsScalar(*dq_1.InputDefs()[1])) { - is_per_tensor_scale = false; - auto weight_scale_shape = dq_1.InputDefs()[1]->Shape(); - if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() || - weight_scale_shape->dim(0).dim_value() != bias_size) { - continue; - } - - const auto& dq_attrs = dq_1.GetAttributes(); - if (dq_attrs.find("block_size") != dq_attrs.end()) { - continue; - } - - int64_t axis = 1; - if (dq_attrs.find("axis") != dq_attrs.end()) { - axis = dq_attrs.at("axis").i(); - } - - int64_t expected_axis = 0; - if (node.OpType() == "Gemm") { - int64_t transB = 0; - if (const auto& attr = node.GetAttributes().find("transB"); attr != node.GetAttributes().end()) { - transB = attr->second.i(); - } - expected_axis = transB == 0 ? 1 : 0; - } - - if (axis != expected_axis) { - continue; - } - } - - // Bias is quantized to int32. - ONNX_NAMESPACE::TypeProto int32_type_proto; - int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); - auto scale_type = dq_1.InputDefs()[1]->TypeAsProto(); // Maybe per-tensor (scalar) or per-channel (1D) scale. - ONNX_NAMESPACE::TypeProto bias_dq_type; - bias_dq_type.mutable_tensor_type()->set_elem_type(scale_type->tensor_type().elem_type()); - bias_dq_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(bias_size); - - // scale = input_scale_0 * input_scale_1. - NodeArg& scale_node_arg = - graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"), scale_type); - Node& mul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Scale node", - {dq_0.MutableInputDefs()[1], dq_1.MutableInputDefs()[1]}, {&scale_node_arg}, nullptr, - node.Domain()); - - // fp_bias / scale. - NodeArg& bias_div_node_arg = - graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), &bias_dq_type); - Node& div_node = - graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node", - {node.MutableInputDefs()[2], &scale_node_arg}, {&bias_div_node_arg}, nullptr, node.Domain()); - graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1); - - // Round(fp_bias / scale). - NodeArg& bias_div_round_node_arg = - graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), &bias_dq_type); - Node& round_node = - graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node", - {&bias_div_node_arg}, {&bias_div_round_node_arg}, nullptr, node.Domain()); - graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0); - - // Cast(round(fp_bias / scale)) to int32. - NodeArg& bias_int32_node_arg = - graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &int32_type_proto); - Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias int32 node", - {&bias_div_round_node_arg}, {&bias_int32_node_arg}, nullptr, node.Domain()); - cast_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_INT32)); - graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0); - - // Bias DQ node produces output to Conv/Gemm node's input_2, with scale = input_scale_0 * input_scale_1, zp = 0. - NodeArg& bias_dq_node_arg = - graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), &bias_dq_type); - Node& dq_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node", - {&bias_int32_node_arg, &scale_node_arg}, {&bias_dq_node_arg}, nullptr, node.Domain()); - if (!is_per_tensor_scale) { - dq_node.AddAttribute("axis", static_cast(0)); - } - - graph.AddEdge(cast_node.Index(), dq_node.Index(), 0, 0); - graph.AddEdge(mul_node.Index(), dq_node.Index(), 0, 1); - node.MutableInputDefs()[2] = &bias_dq_node_arg; - graph.AddEdge(dq_node.Index(), node.Index(), 0, 2); - - modified = true; - } - - return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h b/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h deleted file mode 100644 index 0297def260..0000000000 --- a/onnxruntime/core/optimizer/qdq_transformer/bias_quantization.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/optimizer/graph_transformer.h" - -namespace onnxruntime { - -/** - * @class BiasQuantization - * - * Some quantized models do not have Gemm/Conv's bias quantized. This optimization adds a subgraph to quantize the bias - * with scale = scale_input_0 * scale_input_1 and zero_point = 0. - * - * Normally the ConstantFolding optimizer would fold the bias initializer into an int32_t initializer, which is consumed - * by a DequantizeLinear node. - */ -class BiasQuantization : public GraphTransformer { - public: - BiasQuantization() noexcept : GraphTransformer("BiasQuantization") {} - - private: - Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc new file mode 100644 index 0000000000..a451e3ad60 --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc @@ -0,0 +1,214 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/qdq_transformer/weight_bias_quantization.h" + +#include "core/common/common.h" +#include "core/util/qmath.h" +#include "core/graph/graph_utils.h" +#include "core/graph/graph_viewer.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" + +namespace onnxruntime { + +Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + const GraphViewer graph_viewer{graph}; + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto node_idx : node_indices) { + auto* node_ptr = graph.GetNode(node_idx); + if (!node_ptr) { + continue; + } + + Node& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (node.OpType() != "Conv" && node.OpType() != "ConvTranspose" && node.OpType() != "Gemm") { + continue; + } + + const auto& input_defs = node.InputDefs(); + const NodeArg* input_arg = input_defs[0]; + const NodeArg* weight_arg = input_defs[1]; + const NodeArg* bias_arg = input_defs.size() >= 3 && input_defs[2]->Exists() ? input_defs[2] : nullptr; + const Node* parent_node_0 = graph.GetProducerNode(input_arg->Name()); + const Node* parent_node_1 = graph.GetProducerNode(weight_arg->Name()); + + // Currently we require input is Dequantized with per-tensor scale. + if (!parent_node_0 || parent_node_0->OpType() != QDQ::DQOpName || + !optimizer_utils::IsScalar(*parent_node_0->InputDefs()[1])) { + continue; + } + + Node& dq_0 = *graph.GetNode(parent_node_0->Index()); + Node* dq_1 = nullptr; + const ONNX_NAMESPACE::TensorProto* weight_proto = nullptr; + if (parent_node_1 && parent_node_1->OpType() == QDQ::DQOpName) { + dq_1 = graph.GetNode(parent_node_1->Index()); + } else if (!graph_utils::IsInitializer(graph, weight_arg->Name(), true) || + !graph.GetInitializedTensor(weight_arg->Name(), weight_proto) || + weight_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + // Support float32 weight initializer only for now. + continue; + } + + int64_t bias_size = -1; + if (bias_arg) { + auto bias_shape = bias_arg->Shape(); + if (!graph_utils::IsInitializer(graph, bias_arg->Name(), true) || !bias_shape || bias_shape->dim_size() != 1) { + continue; + } + + bias_size = bias_shape->dim(0).dim_value(); + } + + // Nothing to do if neither weight nor bias is initializer. + if (dq_1 && bias_size == -1) { + continue; + } + + bool is_per_tensor_scale = true; + // If weight is quantized, it's either per-tensor or per-channel on specific axis (0 for Conv and 1 for Gemm). + if (dq_1 && !optimizer_utils::IsScalar(*dq_1->InputDefs()[1])) { + is_per_tensor_scale = false; + auto weight_scale_shape = dq_1->InputDefs()[1]->Shape(); + if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() || + weight_scale_shape->dim(0).dim_value() != bias_size) { + continue; + } + + const auto& dq_attrs = dq_1->GetAttributes(); + if (dq_attrs.find("block_size") != dq_attrs.end()) { + continue; + } + + int64_t axis = 1; + if (auto axis_iter = dq_attrs.find("axis"); axis_iter != dq_attrs.end()) { + axis = axis_iter->second.i(); + } + + int64_t expected_axis = 0; + if (node.OpType() == "Gemm") { + int64_t transB = 0; + const auto& gemm_attrs = node.GetAttributes(); + if (auto trans_b_iter = gemm_attrs.find("transB"); trans_b_iter != gemm_attrs.end()) { + transB = trans_b_iter->second.i(); + } + expected_axis = transB == 0 ? 1 : 0; + } + + if (axis != expected_axis) { + continue; + } + } + + NodeArg* weight_scale_arg = nullptr; + if (!dq_1) { + auto initializer = std::make_unique(*weight_proto, graph.ModelPath()); + const float* weight_data = initializer->data(); + + // Quantize float32 weight to int8_t (per-tensor, symmetric). + // int8_t quantization of input[1] works with input[0] of all types. + float scale; + int8_t zp; + GetQuantizationParameter(weight_data, static_cast(initializer->size()), scale, zp, nullptr); + + // Weight scale initializer. + ONNX_NAMESPACE::TensorProto weight_scale_proto; + weight_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_scale")); + weight_scale_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + weight_scale_proto.mutable_float_data()->Add(scale); + weight_scale_arg = &graph_utils::AddInitializer(graph, weight_scale_proto); + + // Weight zero point initializer. + ONNX_NAMESPACE::TensorProto weight_zp_proto; + weight_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_zp")); + weight_zp_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8); + weight_zp_proto.mutable_int32_data()->Add(static_cast(zp)); + NodeArg& weight_zp_arg = graph_utils::AddInitializer(graph, weight_zp_proto); + + // Q from float32 to int8. + ONNX_NAMESPACE::TypeProto weight_q_type_proto; + weight_q_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT8); + *weight_q_type_proto.mutable_tensor_type()->mutable_shape() = *weight_arg->Shape(); + NodeArg& weight_q_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_weight_q"), &weight_q_type_proto); + Node& weight_q_node = graph.AddNode( + graph.GenerateNodeArgName(node.Name() + "_weight_q"), QDQ::QOpName, "Weight Q node", + {node.MutableInputDefs()[1], weight_scale_arg, &weight_zp_arg}, {&weight_q_arg}, nullptr, node.Domain()); + + // DQ from int8 to float32. + NodeArg& weight_dq_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_weight_dq"), weight_arg->TypeAsProto()); + Node& weight_dq_node = + graph.AddNode(graph.GenerateNodeArgName(node.Name() + "_weight_dq"), QDQ::DQOpName, "Weight DQ node", + {&weight_q_arg, weight_scale_arg, &weight_zp_arg}, {&weight_dq_arg}, nullptr, node.Domain()); + graph.AddEdge(weight_q_node.Index(), weight_dq_node.Index(), 0, 0); + node.MutableInputDefs()[1] = &weight_dq_arg; + graph.AddEdge(weight_dq_node.Index(), node.Index(), 0, 1); + } else { + weight_scale_arg = dq_1->MutableInputDefs()[1]; + } + + if (bias_size != -1) { + // Bias is quantized to int32. Q cannot support int32 as target type, need to compose the whole computation. + // bias_scale = input_scale * weight_scale. + NodeArg& bias_scale_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"), + weight_scale_arg->TypeAsProto()); + Node& mul_node = + graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Bias scale node", + {dq_0.MutableInputDefs()[1], weight_scale_arg}, {&bias_scale_arg}, nullptr, node.Domain()); + + // fp_bias / scale. + NodeArg& bias_div_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), bias_arg->TypeAsProto()); + Node& div_node = + graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node", + {node.MutableInputDefs()[2], &bias_scale_arg}, {&bias_div_arg}, nullptr, node.Domain()); + graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1); + + // Round(fp_bias / scale). + NodeArg& bias_div_round_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), bias_arg->TypeAsProto()); + Node& round_node = + graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node", + {&bias_div_arg}, {&bias_div_round_arg}, nullptr, node.Domain()); + graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0); + + // Cast(Round(fp_bias / scale)) to int32. + ONNX_NAMESPACE::TypeProto bias_int32_type_proto; + bias_int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + *bias_int32_type_proto.mutable_tensor_type()->mutable_shape() = *bias_arg->Shape(); + NodeArg& bias_int32_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &bias_int32_type_proto); + Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias INT32 node", + {&bias_div_round_arg}, {&bias_int32_arg}, nullptr, node.Domain()); + cast_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_INT32)); + graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0); + + // Bias DQ node produces output to Conv/Gemm node's input_2, with scale = input_scale_0 * input_scale_1, zp = 0. + NodeArg& bias_dq_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), bias_arg->TypeAsProto()); + Node& bias_dq_node = + graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node", + {&bias_int32_arg, &bias_scale_arg}, {&bias_dq_arg}, nullptr, node.Domain()); + if (!is_per_tensor_scale) { + bias_dq_node.AddAttribute("axis", static_cast(0)); + } + + graph.AddEdge(cast_node.Index(), bias_dq_node.Index(), 0, 0); + graph.AddEdge(mul_node.Index(), bias_dq_node.Index(), 0, 1); + node.MutableInputDefs()[2] = &bias_dq_arg; + graph.AddEdge(bias_dq_node.Index(), node.Index(), 0, 2); + } + + modified = true; + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.h b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.h new file mode 100644 index 0000000000..0417fbc381 --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** + * @class WeightBiasQuantization + * + * Some quantized models do not have Gemm/Conv/ConvTranspose's weight and/or bias quantized. This optimization adds + * subgraphs with Q->DQ after weight and/or bias. It's possible that the ConstantFolding optimizer would fold the Q Op + * so that weight and/or bias initializers are folded to initializers in target data types, followed by DQ Op. + * For weight, the Q output is a symmetric per-tensor INT8 tensor. + * For bias, the Q's scale = scale_input_0 * scale_input_1 and zero_point = (INT32)0. + */ +class WeightBiasQuantization : public GraphTransformer { + public: + WeightBiasQuantization() noexcept : GraphTransformer("WeightBiasQuantization") {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 043b92d7ef..53e66f5e9a 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -11,7 +11,7 @@ #include "core/graph/onnx_protobuf.h" #include "core/mlas/inc/mlas.h" #include "core/optimizer/double_qdq_pairs_remover.h" -#include "core/optimizer/qdq_transformer/bias_quantization.h" +#include "core/optimizer/qdq_transformer/weight_bias_quantization.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/qdq_propagation.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" @@ -4848,7 +4848,7 @@ TEST(QDQTransformerTests, DropDQSelectorWithDQProducingGraphOutput) { } #endif // !defined(DISABLE_CONTRIB_OPS) -TEST(QDQTransformerTests, BiasQuantization_Conv) { +TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Bias) { auto test_case = [](bool use_contrib_qdq) { auto build_test_case = [&](ModelTestBuilder& builder) { NodeArg* input_arg = builder.MakeInput({1, 24, 128, 128}, std::numeric_limits::min(), @@ -4896,7 +4896,93 @@ TEST(QDQTransformerTests, BiasQuantization_Conv) { #endif } -TEST(QDQTransformerTests, BiasQuantization_Gemm) { +TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Weight_Bias) { + auto test_case = [](bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_arg = builder.MakeInput({1, 24, 67, 67}, std::numeric_limits::min(), + std::numeric_limits::max()); + NodeArg* weight_arg = builder.MakeInitializer({24, 1, 5, 5}, -0.1f, 0.1f); + NodeArg* bias_arg = builder.MakeInitializer({24}, -0.1f, 0.1f); + NodeArg* input_dq_arg = builder.MakeIntermediate(); + NodeArg* conv_dq_arg = builder.MakeIntermediate(); + NodeArg* output_arg = builder.MakeOutput(); + + builder.AddDequantizeLinearNode(input_arg, 0.014f, static_cast(127), input_dq_arg, + use_contrib_qdq); + auto& conv_node = builder.AddNode("Conv", {input_dq_arg, weight_arg, bias_arg}, {conv_dq_arg}); + conv_node.AddAttribute("dilations", std::vector{1, 1}); + conv_node.AddAttribute("kernel_shape", std::vector{5, 5}); + conv_node.AddAttribute("strides", std::vector{2, 2}); + conv_node.AddAttribute("group", static_cast(24)); + conv_node.AddAttribute("pads", std::vector{0, 0, 0, 0}); + builder.AddQuantizeLinearNode(conv_dq_arg, 0.014f, static_cast(127), output_arg, + use_contrib_qdq); + }; + + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count["QLinearConv"], 1); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18); + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19); + }; + + test_case(false); +#if !defined(DISABLE_CONTRIB_OPS) + test_case(true); +#endif +} + +TEST(QDQTransformerTests, WeightBiasQuantization_ConvTranspose_Weight) { + auto test_case = [](bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_arg = builder.MakeInput({1, 3, 4, 4}, std::numeric_limits::min(), + std::numeric_limits::max()); + NodeArg* weight_arg = builder.MakeInitializer({3, 3, 3, 3}, -0.1f, 0.1f); + NodeArg* input_dq_arg = builder.MakeIntermediate(); + NodeArg* conv_dq_arg = builder.MakeIntermediate(); + NodeArg* output_arg = builder.MakeOutput(); + + builder.AddDequantizeLinearNode(input_arg, 0.014f, static_cast(127), input_dq_arg, + use_contrib_qdq); + auto& conv_node = builder.AddNode("ConvTranspose", {input_dq_arg, weight_arg}, {conv_dq_arg}); + conv_node.AddAttribute("dilations", std::vector{1, 1}); + conv_node.AddAttribute("kernel_shape", std::vector{3, 3}); + conv_node.AddAttribute("strides", std::vector{1, 1}); + conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{0, 0, 0, 0}); + builder.AddQuantizeLinearNode(conv_dq_arg, 0.014f, static_cast(127), output_arg, + use_contrib_qdq); + }; + + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + // No QLinearConvTranspose CPU kernel. Check the graph only. + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1); + EXPECT_EQ(op_to_count["DequantizeLinear"] + op_to_count["com.microsoft.DequantizeLinear"], 2); + EXPECT_EQ(op_to_count["ConvTranspose"], 1); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18); + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19); + }; + + test_case(false); +#if !defined(DISABLE_CONTRIB_OPS) + test_case(true); +#endif +} + +#if !defined(DISABLE_CONTRIB_OPS) + +TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Bias) { auto test_case = [](bool use_contrib_qdq) { auto build_test_case = [&](ModelTestBuilder& builder) { NodeArg* input_arg = @@ -4933,10 +5019,80 @@ TEST(QDQTransformerTests, BiasQuantization_Gemm) { }; test_case(false); -#if !defined(DISABLE_CONTRIB_OPS) test_case(true); -#endif } +TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight) { + auto test_case = [](bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_arg = + builder.MakeInput({1, 32}, std::numeric_limits::min(), std::numeric_limits::max()); + NodeArg* weight_arg = builder.MakeInitializer({32, 16}, -0.1f, 0.1f); + NodeArg* input_dq_arg = builder.MakeIntermediate(); + NodeArg* gemm_dq_arg = builder.MakeIntermediate(); + NodeArg* output_arg = builder.MakeOutput(); + + builder.AddDequantizeLinearNode(input_arg, 0.001f, static_cast(127), input_dq_arg, + use_contrib_qdq); + builder.AddNode("Gemm", {input_dq_arg, weight_arg}, {gemm_dq_arg}); + builder.AddQuantizeLinearNode(gemm_dq_arg, 0.144f, static_cast(127), output_arg, + use_contrib_qdq); + }; + + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18); + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19); + }; + + test_case(false); + test_case(true); +} + +TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight_Bias) { + auto test_case = [](bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_arg = + builder.MakeInput({1, 32}, std::numeric_limits::min(), std::numeric_limits::max()); + NodeArg* weight_arg = builder.MakeInitializer({16, 32}, -0.1f, 0.1f); + NodeArg* bias_arg = builder.MakeInitializer({16}, -0.1f, 0.1f); + NodeArg* input_dq_arg = builder.MakeIntermediate(); + NodeArg* gemm_dq_arg = builder.MakeIntermediate(); + NodeArg* output_arg = builder.MakeOutput(); + + builder.AddDequantizeLinearNode(input_arg, 0.001f, static_cast(127), input_dq_arg, + use_contrib_qdq); + auto& gemm_node = builder.AddNode("Gemm", {input_dq_arg, weight_arg, bias_arg}, {gemm_dq_arg}); + gemm_node.AddAttribute("transB", static_cast(1)); + builder.AddQuantizeLinearNode(gemm_dq_arg, 0.144f, static_cast(127), output_arg, + use_contrib_qdq); + }; + + auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18); + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19); + }; + + test_case(false); + test_case(true); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) + } // namespace test } // namespace onnxruntime