Quantize Weight for Gemm/Conv on Quantized Model (#22969)

Some quantized models have QDQ around Conv/Gemm but the weight and/or
bias are not quantized. This PR adds WeightBiasQuantization optimizer to
quantize float weight and/or bias to INT8 and INT32 tensors
respectively. We only do this for weight and/or bias initializer so that
ConstantFolding will fold the sub-graph to real quantized initializers
during the graph optimization next round.
This commit is contained in:
Vincent Wang 2025-01-08 10:00:24 +08:00 committed by GitHub
parent c75681a404
commit ff0ab0a8a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 404 additions and 183 deletions

View file

@ -63,7 +63,7 @@
#ifdef MLAS_TARGET_AMD64_IX86
#include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h"
#endif
#include "core/optimizer/qdq_transformer/bias_quantization.h"
#include "core/optimizer/qdq_transformer/weight_bias_quantization.h"
#include "core/optimizer/qdq_transformer/clip_quantizelinear.h"
#include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h"
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
@ -245,7 +245,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
if (!disable_quant_qdq) {
transformers.emplace_back(std::make_unique<QDQPropagationTransformer>());
transformers.emplace_back(std::make_unique<BiasQuantization>());
transformers.emplace_back(std::make_unique<WeightBiasQuantization>());
// EnsureUniqueDQForNodeUnit is actually a required graph transformation. The unique DQ per QDQ node unit input
// condition that it ensures is important for the partitioning that happens after Level1 optimizers are run.

View file

@ -1,149 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/qdq_transformer/bias_quantization.h"
#include "core/common/common.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/qdq_transformer/qdq_util.h"
namespace onnxruntime {
Status BiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
const GraphViewer graph_viewer{graph};
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (const auto node_idx : node_indices) {
auto* node_ptr = graph.GetNode(node_idx);
if (!node_ptr) {
continue;
}
Node& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
const auto& input_defs = node.InputDefs();
// It's Conv/Gemm node with an initializer bias.
if ((node.OpType() != "Conv" && node.OpType() != "Gemm") || input_defs.size() < 3 || !input_defs[2]->Exists() ||
!graph_utils::IsInitializer(graph, input_defs[2]->Name(), true)) {
continue;
}
auto bias_shape = input_defs[2]->Shape();
if (!bias_shape || bias_shape->dim_size() != 1) {
continue;
}
int64_t bias_size = bias_shape->dim(0).dim_value();
// input_0 and input_1 are outputs of DequantizeLinear nodes.
const Node* parent_node_0 = graph.GetProducerNode(input_defs[0]->Name());
const Node* parent_node_1 = graph.GetProducerNode(input_defs[1]->Name());
if (!parent_node_0 || !parent_node_1 || parent_node_0->OpType() != QDQ::DQOpName ||
parent_node_1->OpType() != QDQ::DQOpName) {
continue;
}
Node& dq_0 = *graph.GetNode(parent_node_0->Index());
Node& dq_1 = *graph.GetNode(parent_node_1->Index());
// Currently we require input_0 is per-tensor scale.
if (!optimizer_utils::IsScalar(*dq_0.InputDefs()[1])) {
continue;
}
// For input_1, it's either per-tensor scale or per-channel scale on specific axis (0 for Conv and 1 for Gemm).
bool is_per_tensor_scale = true;
if (!optimizer_utils::IsScalar(*dq_1.InputDefs()[1])) {
is_per_tensor_scale = false;
auto weight_scale_shape = dq_1.InputDefs()[1]->Shape();
if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() ||
weight_scale_shape->dim(0).dim_value() != bias_size) {
continue;
}
const auto& dq_attrs = dq_1.GetAttributes();
if (dq_attrs.find("block_size") != dq_attrs.end()) {
continue;
}
int64_t axis = 1;
if (dq_attrs.find("axis") != dq_attrs.end()) {
axis = dq_attrs.at("axis").i();
}
int64_t expected_axis = 0;
if (node.OpType() == "Gemm") {
int64_t transB = 0;
if (const auto& attr = node.GetAttributes().find("transB"); attr != node.GetAttributes().end()) {
transB = attr->second.i();
}
expected_axis = transB == 0 ? 1 : 0;
}
if (axis != expected_axis) {
continue;
}
}
// Bias is quantized to int32.
ONNX_NAMESPACE::TypeProto int32_type_proto;
int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
auto scale_type = dq_1.InputDefs()[1]->TypeAsProto(); // Maybe per-tensor (scalar) or per-channel (1D) scale.
ONNX_NAMESPACE::TypeProto bias_dq_type;
bias_dq_type.mutable_tensor_type()->set_elem_type(scale_type->tensor_type().elem_type());
bias_dq_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(bias_size);
// scale = input_scale_0 * input_scale_1.
NodeArg& scale_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"), scale_type);
Node& mul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Scale node",
{dq_0.MutableInputDefs()[1], dq_1.MutableInputDefs()[1]}, {&scale_node_arg}, nullptr,
node.Domain());
// fp_bias / scale.
NodeArg& bias_div_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), &bias_dq_type);
Node& div_node =
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node",
{node.MutableInputDefs()[2], &scale_node_arg}, {&bias_div_node_arg}, nullptr, node.Domain());
graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1);
// Round(fp_bias / scale).
NodeArg& bias_div_round_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), &bias_dq_type);
Node& round_node =
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node",
{&bias_div_node_arg}, {&bias_div_round_node_arg}, nullptr, node.Domain());
graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0);
// Cast(round(fp_bias / scale)) to int32.
NodeArg& bias_int32_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &int32_type_proto);
Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias int32 node",
{&bias_div_round_node_arg}, {&bias_int32_node_arg}, nullptr, node.Domain());
cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32));
graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0);
// Bias DQ node produces output to Conv/Gemm node's input_2, with scale = input_scale_0 * input_scale_1, zp = 0.
NodeArg& bias_dq_node_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), &bias_dq_type);
Node& dq_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node",
{&bias_int32_node_arg, &scale_node_arg}, {&bias_dq_node_arg}, nullptr, node.Domain());
if (!is_per_tensor_scale) {
dq_node.AddAttribute("axis", static_cast<int64_t>(0));
}
graph.AddEdge(cast_node.Index(), dq_node.Index(), 0, 0);
graph.AddEdge(mul_node.Index(), dq_node.Index(), 0, 1);
node.MutableInputDefs()[2] = &bias_dq_node_arg;
graph.AddEdge(dq_node.Index(), node.Index(), 0, 2);
modified = true;
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -1,27 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
/**
* @class BiasQuantization
*
* Some quantized models do not have Gemm/Conv's bias quantized. This optimization adds a subgraph to quantize the bias
* with scale = scale_input_0 * scale_input_1 and zero_point = 0.
*
* Normally the ConstantFolding optimizer would fold the bias initializer into an int32_t initializer, which is consumed
* by a DequantizeLinear node.
*/
class BiasQuantization : public GraphTransformer {
public:
BiasQuantization() noexcept : GraphTransformer("BiasQuantization") {}
private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -0,0 +1,214 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/qdq_transformer/weight_bias_quantization.h"
#include "core/common/common.h"
#include "core/util/qmath.h"
#include "core/graph/graph_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/qdq_transformer/qdq_util.h"
namespace onnxruntime {
Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
const GraphViewer graph_viewer{graph};
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (const auto node_idx : node_indices) {
auto* node_ptr = graph.GetNode(node_idx);
if (!node_ptr) {
continue;
}
Node& node = *node_ptr;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
if (node.OpType() != "Conv" && node.OpType() != "ConvTranspose" && node.OpType() != "Gemm") {
continue;
}
const auto& input_defs = node.InputDefs();
const NodeArg* input_arg = input_defs[0];
const NodeArg* weight_arg = input_defs[1];
const NodeArg* bias_arg = input_defs.size() >= 3 && input_defs[2]->Exists() ? input_defs[2] : nullptr;
const Node* parent_node_0 = graph.GetProducerNode(input_arg->Name());
const Node* parent_node_1 = graph.GetProducerNode(weight_arg->Name());
// Currently we require input is Dequantized with per-tensor scale.
if (!parent_node_0 || parent_node_0->OpType() != QDQ::DQOpName ||
!optimizer_utils::IsScalar(*parent_node_0->InputDefs()[1])) {
continue;
}
Node& dq_0 = *graph.GetNode(parent_node_0->Index());
Node* dq_1 = nullptr;
const ONNX_NAMESPACE::TensorProto* weight_proto = nullptr;
if (parent_node_1 && parent_node_1->OpType() == QDQ::DQOpName) {
dq_1 = graph.GetNode(parent_node_1->Index());
} else if (!graph_utils::IsInitializer(graph, weight_arg->Name(), true) ||
!graph.GetInitializedTensor(weight_arg->Name(), weight_proto) ||
weight_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
// Support float32 weight initializer only for now.
continue;
}
int64_t bias_size = -1;
if (bias_arg) {
auto bias_shape = bias_arg->Shape();
if (!graph_utils::IsInitializer(graph, bias_arg->Name(), true) || !bias_shape || bias_shape->dim_size() != 1) {
continue;
}
bias_size = bias_shape->dim(0).dim_value();
}
// Nothing to do if neither weight nor bias is initializer.
if (dq_1 && bias_size == -1) {
continue;
}
bool is_per_tensor_scale = true;
// If weight is quantized, it's either per-tensor or per-channel on specific axis (0 for Conv and 1 for Gemm).
if (dq_1 && !optimizer_utils::IsScalar(*dq_1->InputDefs()[1])) {
is_per_tensor_scale = false;
auto weight_scale_shape = dq_1->InputDefs()[1]->Shape();
if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() ||
weight_scale_shape->dim(0).dim_value() != bias_size) {
continue;
}
const auto& dq_attrs = dq_1->GetAttributes();
if (dq_attrs.find("block_size") != dq_attrs.end()) {
continue;
}
int64_t axis = 1;
if (auto axis_iter = dq_attrs.find("axis"); axis_iter != dq_attrs.end()) {
axis = axis_iter->second.i();
}
int64_t expected_axis = 0;
if (node.OpType() == "Gemm") {
int64_t transB = 0;
const auto& gemm_attrs = node.GetAttributes();
if (auto trans_b_iter = gemm_attrs.find("transB"); trans_b_iter != gemm_attrs.end()) {
transB = trans_b_iter->second.i();
}
expected_axis = transB == 0 ? 1 : 0;
}
if (axis != expected_axis) {
continue;
}
}
NodeArg* weight_scale_arg = nullptr;
if (!dq_1) {
auto initializer = std::make_unique<Initializer>(*weight_proto, graph.ModelPath());
const float* weight_data = initializer->data<float>();
// Quantize float32 weight to int8_t (per-tensor, symmetric).
// int8_t quantization of input[1] works with input[0] of all types.
float scale;
int8_t zp;
GetQuantizationParameter(weight_data, static_cast<int64_t>(initializer->size()), scale, zp, nullptr);
// Weight scale initializer.
ONNX_NAMESPACE::TensorProto weight_scale_proto;
weight_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_scale"));
weight_scale_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
weight_scale_proto.mutable_float_data()->Add(scale);
weight_scale_arg = &graph_utils::AddInitializer(graph, weight_scale_proto);
// Weight zero point initializer.
ONNX_NAMESPACE::TensorProto weight_zp_proto;
weight_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_zp"));
weight_zp_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8);
weight_zp_proto.mutable_int32_data()->Add(static_cast<int32_t>(zp));
NodeArg& weight_zp_arg = graph_utils::AddInitializer(graph, weight_zp_proto);
// Q from float32 to int8.
ONNX_NAMESPACE::TypeProto weight_q_type_proto;
weight_q_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT8);
*weight_q_type_proto.mutable_tensor_type()->mutable_shape() = *weight_arg->Shape();
NodeArg& weight_q_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_weight_q"), &weight_q_type_proto);
Node& weight_q_node = graph.AddNode(
graph.GenerateNodeArgName(node.Name() + "_weight_q"), QDQ::QOpName, "Weight Q node",
{node.MutableInputDefs()[1], weight_scale_arg, &weight_zp_arg}, {&weight_q_arg}, nullptr, node.Domain());
// DQ from int8 to float32.
NodeArg& weight_dq_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_weight_dq"), weight_arg->TypeAsProto());
Node& weight_dq_node =
graph.AddNode(graph.GenerateNodeArgName(node.Name() + "_weight_dq"), QDQ::DQOpName, "Weight DQ node",
{&weight_q_arg, weight_scale_arg, &weight_zp_arg}, {&weight_dq_arg}, nullptr, node.Domain());
graph.AddEdge(weight_q_node.Index(), weight_dq_node.Index(), 0, 0);
node.MutableInputDefs()[1] = &weight_dq_arg;
graph.AddEdge(weight_dq_node.Index(), node.Index(), 0, 1);
} else {
weight_scale_arg = dq_1->MutableInputDefs()[1];
}
if (bias_size != -1) {
// Bias is quantized to int32. Q cannot support int32 as target type, need to compose the whole computation.
// bias_scale = input_scale * weight_scale.
NodeArg& bias_scale_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"),
weight_scale_arg->TypeAsProto());
Node& mul_node =
graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Bias scale node",
{dq_0.MutableInputDefs()[1], weight_scale_arg}, {&bias_scale_arg}, nullptr, node.Domain());
// fp_bias / scale.
NodeArg& bias_div_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), bias_arg->TypeAsProto());
Node& div_node =
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node",
{node.MutableInputDefs()[2], &bias_scale_arg}, {&bias_div_arg}, nullptr, node.Domain());
graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1);
// Round(fp_bias / scale).
NodeArg& bias_div_round_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), bias_arg->TypeAsProto());
Node& round_node =
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node",
{&bias_div_arg}, {&bias_div_round_arg}, nullptr, node.Domain());
graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0);
// Cast(Round(fp_bias / scale)) to int32.
ONNX_NAMESPACE::TypeProto bias_int32_type_proto;
bias_int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
*bias_int32_type_proto.mutable_tensor_type()->mutable_shape() = *bias_arg->Shape();
NodeArg& bias_int32_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &bias_int32_type_proto);
Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias INT32 node",
{&bias_div_round_arg}, {&bias_int32_arg}, nullptr, node.Domain());
cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32));
graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0);
// Bias DQ node produces output to Conv/Gemm node's input_2, with scale = input_scale_0 * input_scale_1, zp = 0.
NodeArg& bias_dq_arg =
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), bias_arg->TypeAsProto());
Node& bias_dq_node =
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node",
{&bias_int32_arg, &bias_scale_arg}, {&bias_dq_arg}, nullptr, node.Domain());
if (!is_per_tensor_scale) {
bias_dq_node.AddAttribute("axis", static_cast<int64_t>(0));
}
graph.AddEdge(cast_node.Index(), bias_dq_node.Index(), 0, 0);
graph.AddEdge(mul_node.Index(), bias_dq_node.Index(), 0, 1);
node.MutableInputDefs()[2] = &bias_dq_arg;
graph.AddEdge(bias_dq_node.Index(), node.Index(), 0, 2);
}
modified = true;
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
/**
* @class WeightBiasQuantization
*
* Some quantized models do not have Gemm/Conv/ConvTranspose's weight and/or bias quantized. This optimization adds
* subgraphs with Q->DQ after weight and/or bias. It's possible that the ConstantFolding optimizer would fold the Q Op
* so that weight and/or bias initializers are folded to initializers in target data types, followed by DQ Op.
* For weight, the Q output is a symmetric per-tensor INT8 tensor.
* For bias, the Q's scale = scale_input_0 * scale_input_1 and zero_point = (INT32)0.
*/
class WeightBiasQuantization : public GraphTransformer {
public:
WeightBiasQuantization() noexcept : GraphTransformer("WeightBiasQuantization") {}
private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -11,7 +11,7 @@
#include "core/graph/onnx_protobuf.h"
#include "core/mlas/inc/mlas.h"
#include "core/optimizer/double_qdq_pairs_remover.h"
#include "core/optimizer/qdq_transformer/bias_quantization.h"
#include "core/optimizer/qdq_transformer/weight_bias_quantization.h"
#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
@ -4848,7 +4848,7 @@ TEST(QDQTransformerTests, DropDQSelectorWithDQProducingGraphOutput) {
}
#endif // !defined(DISABLE_CONTRIB_OPS)
TEST(QDQTransformerTests, BiasQuantization_Conv) {
TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Bias) {
auto test_case = [](bool use_contrib_qdq) {
auto build_test_case = [&](ModelTestBuilder& builder) {
NodeArg* input_arg = builder.MakeInput<uint8_t>({1, 24, 128, 128}, std::numeric_limits<uint8_t>::min(),
@ -4896,7 +4896,93 @@ TEST(QDQTransformerTests, BiasQuantization_Conv) {
#endif
}
TEST(QDQTransformerTests, BiasQuantization_Gemm) {
TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Weight_Bias) {
auto test_case = [](bool use_contrib_qdq) {
auto build_test_case = [&](ModelTestBuilder& builder) {
NodeArg* input_arg = builder.MakeInput<uint8_t>({1, 24, 67, 67}, std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
NodeArg* weight_arg = builder.MakeInitializer<float>({24, 1, 5, 5}, -0.1f, 0.1f);
NodeArg* bias_arg = builder.MakeInitializer<float>({24}, -0.1f, 0.1f);
NodeArg* input_dq_arg = builder.MakeIntermediate();
NodeArg* conv_dq_arg = builder.MakeIntermediate();
NodeArg* output_arg = builder.MakeOutput();
builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.014f, static_cast<uint8_t>(127), input_dq_arg,
use_contrib_qdq);
auto& conv_node = builder.AddNode("Conv", {input_dq_arg, weight_arg, bias_arg}, {conv_dq_arg});
conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{5, 5});
conv_node.AddAttribute("strides", std::vector<int64_t>{2, 2});
conv_node.AddAttribute("group", static_cast<int64_t>(24));
conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
builder.AddQuantizeLinearNode<uint8_t>(conv_dq_arg, 0.014f, static_cast<uint8_t>(127), output_arg,
use_contrib_qdq);
};
auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
EXPECT_EQ(op_to_count["QLinearConv"], 1);
};
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
};
test_case(false);
#if !defined(DISABLE_CONTRIB_OPS)
test_case(true);
#endif
}
TEST(QDQTransformerTests, WeightBiasQuantization_ConvTranspose_Weight) {
auto test_case = [](bool use_contrib_qdq) {
auto build_test_case = [&](ModelTestBuilder& builder) {
NodeArg* input_arg = builder.MakeInput<uint8_t>({1, 3, 4, 4}, std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
NodeArg* weight_arg = builder.MakeInitializer<float>({3, 3, 3, 3}, -0.1f, 0.1f);
NodeArg* input_dq_arg = builder.MakeIntermediate();
NodeArg* conv_dq_arg = builder.MakeIntermediate();
NodeArg* output_arg = builder.MakeOutput();
builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.014f, static_cast<uint8_t>(127), input_dq_arg,
use_contrib_qdq);
auto& conv_node = builder.AddNode("ConvTranspose", {input_dq_arg, weight_arg}, {conv_dq_arg});
conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});
conv_node.AddAttribute("strides", std::vector<int64_t>{1, 1});
conv_node.AddAttribute("group", static_cast<int64_t>(1));
conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
builder.AddQuantizeLinearNode<uint8_t>(conv_dq_arg, 0.014f, static_cast<uint8_t>(127), output_arg,
use_contrib_qdq);
};
auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
// No QLinearConvTranspose CPU kernel. Check the graph only.
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1);
EXPECT_EQ(op_to_count["DequantizeLinear"] + op_to_count["com.microsoft.DequantizeLinear"], 2);
EXPECT_EQ(op_to_count["ConvTranspose"], 1);
};
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
};
test_case(false);
#if !defined(DISABLE_CONTRIB_OPS)
test_case(true);
#endif
}
#if !defined(DISABLE_CONTRIB_OPS)
TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Bias) {
auto test_case = [](bool use_contrib_qdq) {
auto build_test_case = [&](ModelTestBuilder& builder) {
NodeArg* input_arg =
@ -4933,10 +5019,80 @@ TEST(QDQTransformerTests, BiasQuantization_Gemm) {
};
test_case(false);
#if !defined(DISABLE_CONTRIB_OPS)
test_case(true);
#endif
}
TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight) {
auto test_case = [](bool use_contrib_qdq) {
auto build_test_case = [&](ModelTestBuilder& builder) {
NodeArg* input_arg =
builder.MakeInput<uint8_t>({1, 32}, std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
NodeArg* weight_arg = builder.MakeInitializer<float>({32, 16}, -0.1f, 0.1f);
NodeArg* input_dq_arg = builder.MakeIntermediate();
NodeArg* gemm_dq_arg = builder.MakeIntermediate();
NodeArg* output_arg = builder.MakeOutput();
builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.001f, static_cast<uint8_t>(127), input_dq_arg,
use_contrib_qdq);
builder.AddNode("Gemm", {input_dq_arg, weight_arg}, {gemm_dq_arg});
builder.AddQuantizeLinearNode<uint8_t>(gemm_dq_arg, 0.144f, static_cast<uint8_t>(127), output_arg,
use_contrib_qdq);
};
auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
};
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
};
test_case(false);
test_case(true);
}
TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight_Bias) {
auto test_case = [](bool use_contrib_qdq) {
auto build_test_case = [&](ModelTestBuilder& builder) {
NodeArg* input_arg =
builder.MakeInput<uint8_t>({1, 32}, std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
NodeArg* weight_arg = builder.MakeInitializer<float>({16, 32}, -0.1f, 0.1f);
NodeArg* bias_arg = builder.MakeInitializer<float>({16}, -0.1f, 0.1f);
NodeArg* input_dq_arg = builder.MakeIntermediate();
NodeArg* gemm_dq_arg = builder.MakeIntermediate();
NodeArg* output_arg = builder.MakeOutput();
builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.001f, static_cast<uint8_t>(127), input_dq_arg,
use_contrib_qdq);
auto& gemm_node = builder.AddNode("Gemm", {input_dq_arg, weight_arg, bias_arg}, {gemm_dq_arg});
gemm_node.AddAttribute("transB", static_cast<int64_t>(1));
builder.AddQuantizeLinearNode<uint8_t>(gemm_dq_arg, 0.144f, static_cast<uint8_t>(127), output_arg,
use_contrib_qdq);
};
auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
};
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
};
test_case(false);
test_case(true);
}
#endif // !defined(DISABLE_CONTRIB_OPS)
} // namespace test
} // namespace onnxruntime