mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Quantize Bias for Conv/Gemm on Quantized Model (#22889)
Some quantized models don't have Conv/Gemm node's bias quantized but still leave them in float. This PR is to create a sub-graph to quantize the bias for Conv/Gemm nodes with scale = scale_input_0 * scale_input_1 and zp = 0. We only do this for bias initializer so that ConstantFolding will fold the sub-graph to a real quantized int32 bias initializer during the graph optimization next round.
This commit is contained in:
parent
42ecb05080
commit
1128882bfd
4 changed files with 269 additions and 0 deletions
|
|
@ -63,6 +63,7 @@
|
|||
#ifdef MLAS_TARGET_AMD64_IX86
|
||||
#include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h"
|
||||
#endif
|
||||
#include "core/optimizer/qdq_transformer/bias_quantization.h"
|
||||
#include "core/optimizer/qdq_transformer/clip_quantizelinear.h"
|
||||
#include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
|
||||
|
|
@ -243,6 +244,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
|
||||
if (!disable_quant_qdq) {
|
||||
transformers.emplace_back(std::make_unique<QDQPropagationTransformer>());
|
||||
transformers.emplace_back(std::make_unique<BiasQuantization>());
|
||||
|
||||
// EnsureUniqueDQForNodeUnit is actually a required graph transformation. The unique DQ per QDQ node unit input
|
||||
// condition that it ensures is important for the partitioning that happens after Level1 optimizers are run.
|
||||
|
|
|
|||
149
onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc
Normal file
149
onnxruntime/core/optimizer/qdq_transformer/bias_quantization.cc
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/qdq_transformer/bias_quantization.h"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_util.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status BiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
const GraphViewer graph_viewer{graph};
|
||||
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (const auto node_idx : node_indices) {
|
||||
auto* node_ptr = graph.GetNode(node_idx);
|
||||
if (!node_ptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node& node = *node_ptr;
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
const auto& input_defs = node.InputDefs();
|
||||
|
||||
// It's Conv/Gemm node with an initializer bias.
|
||||
if ((node.OpType() != "Conv" && node.OpType() != "Gemm") || input_defs.size() < 3 || !input_defs[2]->Exists() ||
|
||||
!graph_utils::IsInitializer(graph, input_defs[2]->Name(), true)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto bias_shape = input_defs[2]->Shape();
|
||||
if (!bias_shape || bias_shape->dim_size() != 1) {
|
||||
continue;
|
||||
}
|
||||
int64_t bias_size = bias_shape->dim(0).dim_value();
|
||||
|
||||
// input_0 and input_1 are outputs of DequantizeLinear nodes.
|
||||
const Node* parent_node_0 = graph.GetProducerNode(input_defs[0]->Name());
|
||||
const Node* parent_node_1 = graph.GetProducerNode(input_defs[1]->Name());
|
||||
if (!parent_node_0 || !parent_node_1 || parent_node_0->OpType() != QDQ::DQOpName ||
|
||||
parent_node_1->OpType() != QDQ::DQOpName) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node& dq_0 = *graph.GetNode(parent_node_0->Index());
|
||||
Node& dq_1 = *graph.GetNode(parent_node_1->Index());
|
||||
|
||||
// Currently we require input_0 is per-tensor scale.
|
||||
if (!optimizer_utils::IsScalar(*dq_0.InputDefs()[1])) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// For input_1, it's either per-tensor scale or per-channel scale on specific axis (0 for Conv and 1 for Gemm).
|
||||
bool is_per_tensor_scale = true;
|
||||
if (!optimizer_utils::IsScalar(*dq_1.InputDefs()[1])) {
|
||||
is_per_tensor_scale = false;
|
||||
auto weight_scale_shape = dq_1.InputDefs()[1]->Shape();
|
||||
if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() ||
|
||||
weight_scale_shape->dim(0).dim_value() != bias_size) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& dq_attrs = dq_1.GetAttributes();
|
||||
if (dq_attrs.find("block_size") != dq_attrs.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t axis = 1;
|
||||
if (dq_attrs.find("axis") != dq_attrs.end()) {
|
||||
axis = dq_attrs.at("axis").i();
|
||||
}
|
||||
|
||||
int64_t expected_axis = 0;
|
||||
if (node.OpType() == "Gemm") {
|
||||
int64_t transB = 0;
|
||||
if (const auto& attr = node.GetAttributes().find("transB"); attr != node.GetAttributes().end()) {
|
||||
transB = attr->second.i();
|
||||
}
|
||||
expected_axis = transB == 0 ? 1 : 0;
|
||||
}
|
||||
|
||||
if (axis != expected_axis) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Bias is quantized to int32.
|
||||
ONNX_NAMESPACE::TypeProto int32_type_proto;
|
||||
int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
auto scale_type = dq_1.InputDefs()[1]->TypeAsProto(); // Maybe per-tensor (scalar) or per-channel (1D) scale.
|
||||
ONNX_NAMESPACE::TypeProto bias_dq_type;
|
||||
bias_dq_type.mutable_tensor_type()->set_elem_type(scale_type->tensor_type().elem_type());
|
||||
bias_dq_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(bias_size);
|
||||
|
||||
// scale = input_scale_0 * input_scale_1.
|
||||
NodeArg& scale_node_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"), scale_type);
|
||||
Node& mul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Scale node",
|
||||
{dq_0.MutableInputDefs()[1], dq_1.MutableInputDefs()[1]}, {&scale_node_arg}, nullptr,
|
||||
node.Domain());
|
||||
|
||||
// fp_bias / scale.
|
||||
NodeArg& bias_div_node_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), &bias_dq_type);
|
||||
Node& div_node =
|
||||
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node",
|
||||
{node.MutableInputDefs()[2], &scale_node_arg}, {&bias_div_node_arg}, nullptr, node.Domain());
|
||||
graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1);
|
||||
|
||||
// Round(fp_bias / scale).
|
||||
NodeArg& bias_div_round_node_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), &bias_dq_type);
|
||||
Node& round_node =
|
||||
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node",
|
||||
{&bias_div_node_arg}, {&bias_div_round_node_arg}, nullptr, node.Domain());
|
||||
graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0);
|
||||
|
||||
// Cast(round(fp_bias / scale)) to int32.
|
||||
NodeArg& bias_int32_node_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &int32_type_proto);
|
||||
Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias int32 node",
|
||||
{&bias_div_round_node_arg}, {&bias_int32_node_arg}, nullptr, node.Domain());
|
||||
cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32));
|
||||
graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0);
|
||||
|
||||
// Bias DQ node produces output to Conv/Gemm node's input_2, with scale = input_scale_0 * input_scale_1, zp = 0.
|
||||
NodeArg& bias_dq_node_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), &bias_dq_type);
|
||||
Node& dq_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node",
|
||||
{&bias_int32_node_arg, &scale_node_arg}, {&bias_dq_node_arg}, nullptr, node.Domain());
|
||||
if (!is_per_tensor_scale) {
|
||||
dq_node.AddAttribute("axis", static_cast<int64_t>(0));
|
||||
}
|
||||
|
||||
graph.AddEdge(cast_node.Index(), dq_node.Index(), 0, 0);
|
||||
graph.AddEdge(mul_node.Index(), dq_node.Index(), 0, 1);
|
||||
node.MutableInputDefs()[2] = &bias_dq_node_arg;
|
||||
graph.AddEdge(dq_node.Index(), node.Index(), 0, 2);
|
||||
|
||||
modified = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
* @class BiasQuantization
|
||||
*
|
||||
* Some quantized models do not have Gemm/Conv's bias quantized. This optimization adds a subgraph to quantize the bias
|
||||
* with scale = scale_input_0 * scale_input_1 and zero_point = 0.
|
||||
*
|
||||
* Normally the ConstantFolding optimizer would fold the bias initializer into an int32_t initializer, which is consumed
|
||||
* by a DequantizeLinear node.
|
||||
*/
|
||||
class BiasQuantization : public GraphTransformer {
|
||||
public:
|
||||
BiasQuantization() noexcept : GraphTransformer("BiasQuantization") {}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -11,6 +11,7 @@
|
|||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/optimizer/double_qdq_pairs_remover.h"
|
||||
#include "core/optimizer/qdq_transformer/bias_quantization.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
|
|
@ -4846,5 +4847,95 @@ TEST(QDQTransformerTests, DropDQSelectorWithDQProducingGraphOutput) {
|
|||
}
|
||||
#endif // !defined(DISABLE_CONTRIB_OPS)
|
||||
|
||||
TEST(QDQTransformerTests, BiasQuantization_Conv) {
|
||||
auto test_case = [](bool use_contrib_qdq) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
NodeArg* input_arg = builder.MakeInput<uint8_t>({1, 24, 128, 128}, std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max());
|
||||
NodeArg* weight_arg = builder.MakeInitializer<uint8_t>({24, 1, 3, 3}, std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max());
|
||||
NodeArg* bias_arg = builder.MakeInitializer<float>({24}, -0.1f, 0.1f);
|
||||
NodeArg* input_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* weight_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* conv_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* output_arg = builder.MakeOutput();
|
||||
|
||||
builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.07f, static_cast<uint8_t>(0), input_dq_arg,
|
||||
use_contrib_qdq);
|
||||
auto& weight_dq_node = builder.AddDequantizeLinearNode<uint8_t>(weight_arg, std::vector<float>(24, 0.05f),
|
||||
std::vector<uint8_t>(24, static_cast<uint8_t>(0)),
|
||||
weight_dq_arg, nullptr, use_contrib_qdq);
|
||||
weight_dq_node.AddAttribute("axis", static_cast<int64_t>(0));
|
||||
auto& conv_node = builder.AddNode("Conv", {input_dq_arg, weight_dq_arg, bias_arg}, {conv_dq_arg});
|
||||
conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
|
||||
conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});
|
||||
conv_node.AddAttribute("strides", std::vector<int64_t>{1, 1});
|
||||
conv_node.AddAttribute("group", static_cast<int64_t>(24));
|
||||
conv_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
|
||||
builder.AddQuantizeLinearNode<uint8_t>(conv_dq_arg, 0.14f, static_cast<uint8_t>(127), output_arg,
|
||||
use_contrib_qdq);
|
||||
};
|
||||
|
||||
auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
|
||||
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
|
||||
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
|
||||
EXPECT_EQ(op_to_count["QLinearConv"], 1);
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
|
||||
};
|
||||
|
||||
test_case(false);
|
||||
#if !defined(DISABLE_CONTRIB_OPS)
|
||||
test_case(true);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, BiasQuantization_Gemm) {
|
||||
auto test_case = [](bool use_contrib_qdq) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
NodeArg* input_arg =
|
||||
builder.MakeInput<uint8_t>({1, 32}, std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
|
||||
NodeArg* weight_arg = builder.MakeInitializer<uint8_t>({16, 32}, std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max());
|
||||
NodeArg* bias_arg = builder.MakeInitializer<float>({16}, -0.1f, 0.1f);
|
||||
NodeArg* input_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* weight_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* gemm_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* output_arg = builder.MakeOutput();
|
||||
|
||||
builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.001f, static_cast<uint8_t>(0), input_dq_arg,
|
||||
use_contrib_qdq);
|
||||
builder.AddDequantizeLinearNode<uint8_t>(weight_arg, 0.26f, static_cast<uint8_t>(0), weight_dq_arg,
|
||||
use_contrib_qdq);
|
||||
auto& gemm_node = builder.AddNode("Gemm", {input_dq_arg, weight_dq_arg, bias_arg}, {gemm_dq_arg});
|
||||
gemm_node.AddAttribute("transB", static_cast<int64_t>(1));
|
||||
builder.AddQuantizeLinearNode<uint8_t>(gemm_dq_arg, 0.144f, static_cast<uint8_t>(69), output_arg,
|
||||
use_contrib_qdq);
|
||||
};
|
||||
|
||||
auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
|
||||
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
|
||||
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
|
||||
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
|
||||
};
|
||||
|
||||
test_case(false);
|
||||
#if !defined(DISABLE_CONTRIB_OPS)
|
||||
test_case(true);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue