mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
Quantize Weight for Gemm/Conv on Quantized Model (#22969)
Some quantized models have QDQ around Conv/Gemm but the weight and/or bias are not quantized. This PR adds WeightBiasQuantization optimizer to quantize float weight and/or bias to INT8 and INT32 tensors respectively. We only do this for weight and/or bias initializer so that ConstantFolding will fold the sub-graph to real quantized initializers during the graph optimization next round.
This commit is contained in:
parent
c75681a404
commit
ff0ab0a8a5
6 changed files with 404 additions and 183 deletions
|
|
@ -63,7 +63,7 @@
|
|||
#ifdef MLAS_TARGET_AMD64_IX86
|
||||
#include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h"
|
||||
#endif
|
||||
#include "core/optimizer/qdq_transformer/bias_quantization.h"
|
||||
#include "core/optimizer/qdq_transformer/weight_bias_quantization.h"
|
||||
#include "core/optimizer/qdq_transformer/clip_quantizelinear.h"
|
||||
#include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
|
||||
|
|
@ -245,7 +245,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
|
||||
if (!disable_quant_qdq) {
|
||||
transformers.emplace_back(std::make_unique<QDQPropagationTransformer>());
|
||||
transformers.emplace_back(std::make_unique<BiasQuantization>());
|
||||
transformers.emplace_back(std::make_unique<WeightBiasQuantization>());
|
||||
|
||||
// EnsureUniqueDQForNodeUnit is actually a required graph transformation. The unique DQ per QDQ node unit input
|
||||
// condition that it ensures is important for the partitioning that happens after Level1 optimizers are run.
|
||||
|
|
|
|||
|
|
@ -1,149 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/qdq_transformer/bias_quantization.h"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_util.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status BiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
const GraphViewer graph_viewer{graph};
|
||||
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (const auto node_idx : node_indices) {
|
||||
auto* node_ptr = graph.GetNode(node_idx);
|
||||
if (!node_ptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node& node = *node_ptr;
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
const auto& input_defs = node.InputDefs();
|
||||
|
||||
// It's Conv/Gemm node with an initializer bias.
|
||||
if ((node.OpType() != "Conv" && node.OpType() != "Gemm") || input_defs.size() < 3 || !input_defs[2]->Exists() ||
|
||||
!graph_utils::IsInitializer(graph, input_defs[2]->Name(), true)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto bias_shape = input_defs[2]->Shape();
|
||||
if (!bias_shape || bias_shape->dim_size() != 1) {
|
||||
continue;
|
||||
}
|
||||
int64_t bias_size = bias_shape->dim(0).dim_value();
|
||||
|
||||
// input_0 and input_1 are outputs of DequantizeLinear nodes.
|
||||
const Node* parent_node_0 = graph.GetProducerNode(input_defs[0]->Name());
|
||||
const Node* parent_node_1 = graph.GetProducerNode(input_defs[1]->Name());
|
||||
if (!parent_node_0 || !parent_node_1 || parent_node_0->OpType() != QDQ::DQOpName ||
|
||||
parent_node_1->OpType() != QDQ::DQOpName) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node& dq_0 = *graph.GetNode(parent_node_0->Index());
|
||||
Node& dq_1 = *graph.GetNode(parent_node_1->Index());
|
||||
|
||||
// Currently we require input_0 is per-tensor scale.
|
||||
if (!optimizer_utils::IsScalar(*dq_0.InputDefs()[1])) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// For input_1, it's either per-tensor scale or per-channel scale on specific axis (0 for Conv and 1 for Gemm).
|
||||
bool is_per_tensor_scale = true;
|
||||
if (!optimizer_utils::IsScalar(*dq_1.InputDefs()[1])) {
|
||||
is_per_tensor_scale = false;
|
||||
auto weight_scale_shape = dq_1.InputDefs()[1]->Shape();
|
||||
if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() ||
|
||||
weight_scale_shape->dim(0).dim_value() != bias_size) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& dq_attrs = dq_1.GetAttributes();
|
||||
if (dq_attrs.find("block_size") != dq_attrs.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t axis = 1;
|
||||
if (dq_attrs.find("axis") != dq_attrs.end()) {
|
||||
axis = dq_attrs.at("axis").i();
|
||||
}
|
||||
|
||||
int64_t expected_axis = 0;
|
||||
if (node.OpType() == "Gemm") {
|
||||
int64_t transB = 0;
|
||||
if (const auto& attr = node.GetAttributes().find("transB"); attr != node.GetAttributes().end()) {
|
||||
transB = attr->second.i();
|
||||
}
|
||||
expected_axis = transB == 0 ? 1 : 0;
|
||||
}
|
||||
|
||||
if (axis != expected_axis) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Bias is quantized to int32.
|
||||
ONNX_NAMESPACE::TypeProto int32_type_proto;
|
||||
int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
auto scale_type = dq_1.InputDefs()[1]->TypeAsProto(); // Maybe per-tensor (scalar) or per-channel (1D) scale.
|
||||
ONNX_NAMESPACE::TypeProto bias_dq_type;
|
||||
bias_dq_type.mutable_tensor_type()->set_elem_type(scale_type->tensor_type().elem_type());
|
||||
bias_dq_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(bias_size);
|
||||
|
||||
// scale = input_scale_0 * input_scale_1.
|
||||
NodeArg& scale_node_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"), scale_type);
|
||||
Node& mul_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Scale node",
|
||||
{dq_0.MutableInputDefs()[1], dq_1.MutableInputDefs()[1]}, {&scale_node_arg}, nullptr,
|
||||
node.Domain());
|
||||
|
||||
// fp_bias / scale.
|
||||
NodeArg& bias_div_node_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), &bias_dq_type);
|
||||
Node& div_node =
|
||||
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node",
|
||||
{node.MutableInputDefs()[2], &scale_node_arg}, {&bias_div_node_arg}, nullptr, node.Domain());
|
||||
graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1);
|
||||
|
||||
// Round(fp_bias / scale).
|
||||
NodeArg& bias_div_round_node_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), &bias_dq_type);
|
||||
Node& round_node =
|
||||
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node",
|
||||
{&bias_div_node_arg}, {&bias_div_round_node_arg}, nullptr, node.Domain());
|
||||
graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0);
|
||||
|
||||
// Cast(round(fp_bias / scale)) to int32.
|
||||
NodeArg& bias_int32_node_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &int32_type_proto);
|
||||
Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias int32 node",
|
||||
{&bias_div_round_node_arg}, {&bias_int32_node_arg}, nullptr, node.Domain());
|
||||
cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32));
|
||||
graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0);
|
||||
|
||||
// Bias DQ node produces output to Conv/Gemm node's input_2, with scale = input_scale_0 * input_scale_1, zp = 0.
|
||||
NodeArg& bias_dq_node_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), &bias_dq_type);
|
||||
Node& dq_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node",
|
||||
{&bias_int32_node_arg, &scale_node_arg}, {&bias_dq_node_arg}, nullptr, node.Domain());
|
||||
if (!is_per_tensor_scale) {
|
||||
dq_node.AddAttribute("axis", static_cast<int64_t>(0));
|
||||
}
|
||||
|
||||
graph.AddEdge(cast_node.Index(), dq_node.Index(), 0, 0);
|
||||
graph.AddEdge(mul_node.Index(), dq_node.Index(), 0, 1);
|
||||
node.MutableInputDefs()[2] = &bias_dq_node_arg;
|
||||
graph.AddEdge(dq_node.Index(), node.Index(), 0, 2);
|
||||
|
||||
modified = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
* @class BiasQuantization
|
||||
*
|
||||
* Some quantized models do not have Gemm/Conv's bias quantized. This optimization adds a subgraph to quantize the bias
|
||||
* with scale = scale_input_0 * scale_input_1 and zero_point = 0.
|
||||
*
|
||||
* Normally the ConstantFolding optimizer would fold the bias initializer into an int32_t initializer, which is consumed
|
||||
* by a DequantizeLinear node.
|
||||
*/
|
||||
class BiasQuantization : public GraphTransformer {
|
||||
public:
|
||||
BiasQuantization() noexcept : GraphTransformer("BiasQuantization") {}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,214 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/qdq_transformer/weight_bias_quantization.h"
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/util/qmath.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_util.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger) const {
|
||||
const GraphViewer graph_viewer{graph};
|
||||
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (const auto node_idx : node_indices) {
|
||||
auto* node_ptr = graph.GetNode(node_idx);
|
||||
if (!node_ptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node& node = *node_ptr;
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (node.OpType() != "Conv" && node.OpType() != "ConvTranspose" && node.OpType() != "Gemm") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const NodeArg* input_arg = input_defs[0];
|
||||
const NodeArg* weight_arg = input_defs[1];
|
||||
const NodeArg* bias_arg = input_defs.size() >= 3 && input_defs[2]->Exists() ? input_defs[2] : nullptr;
|
||||
const Node* parent_node_0 = graph.GetProducerNode(input_arg->Name());
|
||||
const Node* parent_node_1 = graph.GetProducerNode(weight_arg->Name());
|
||||
|
||||
// Currently we require input is Dequantized with per-tensor scale.
|
||||
if (!parent_node_0 || parent_node_0->OpType() != QDQ::DQOpName ||
|
||||
!optimizer_utils::IsScalar(*parent_node_0->InputDefs()[1])) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node& dq_0 = *graph.GetNode(parent_node_0->Index());
|
||||
Node* dq_1 = nullptr;
|
||||
const ONNX_NAMESPACE::TensorProto* weight_proto = nullptr;
|
||||
if (parent_node_1 && parent_node_1->OpType() == QDQ::DQOpName) {
|
||||
dq_1 = graph.GetNode(parent_node_1->Index());
|
||||
} else if (!graph_utils::IsInitializer(graph, weight_arg->Name(), true) ||
|
||||
!graph.GetInitializedTensor(weight_arg->Name(), weight_proto) ||
|
||||
weight_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
// Support float32 weight initializer only for now.
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t bias_size = -1;
|
||||
if (bias_arg) {
|
||||
auto bias_shape = bias_arg->Shape();
|
||||
if (!graph_utils::IsInitializer(graph, bias_arg->Name(), true) || !bias_shape || bias_shape->dim_size() != 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bias_size = bias_shape->dim(0).dim_value();
|
||||
}
|
||||
|
||||
// Nothing to do if neither weight nor bias is initializer.
|
||||
if (dq_1 && bias_size == -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool is_per_tensor_scale = true;
|
||||
// If weight is quantized, it's either per-tensor or per-channel on specific axis (0 for Conv and 1 for Gemm).
|
||||
if (dq_1 && !optimizer_utils::IsScalar(*dq_1->InputDefs()[1])) {
|
||||
is_per_tensor_scale = false;
|
||||
auto weight_scale_shape = dq_1->InputDefs()[1]->Shape();
|
||||
if (!weight_scale_shape || weight_scale_shape->dim_size() != 1 || !weight_scale_shape->dim(0).has_dim_value() ||
|
||||
weight_scale_shape->dim(0).dim_value() != bias_size) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& dq_attrs = dq_1->GetAttributes();
|
||||
if (dq_attrs.find("block_size") != dq_attrs.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int64_t axis = 1;
|
||||
if (auto axis_iter = dq_attrs.find("axis"); axis_iter != dq_attrs.end()) {
|
||||
axis = axis_iter->second.i();
|
||||
}
|
||||
|
||||
int64_t expected_axis = 0;
|
||||
if (node.OpType() == "Gemm") {
|
||||
int64_t transB = 0;
|
||||
const auto& gemm_attrs = node.GetAttributes();
|
||||
if (auto trans_b_iter = gemm_attrs.find("transB"); trans_b_iter != gemm_attrs.end()) {
|
||||
transB = trans_b_iter->second.i();
|
||||
}
|
||||
expected_axis = transB == 0 ? 1 : 0;
|
||||
}
|
||||
|
||||
if (axis != expected_axis) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
NodeArg* weight_scale_arg = nullptr;
|
||||
if (!dq_1) {
|
||||
auto initializer = std::make_unique<Initializer>(*weight_proto, graph.ModelPath());
|
||||
const float* weight_data = initializer->data<float>();
|
||||
|
||||
// Quantize float32 weight to int8_t (per-tensor, symmetric).
|
||||
// int8_t quantization of input[1] works with input[0] of all types.
|
||||
float scale;
|
||||
int8_t zp;
|
||||
GetQuantizationParameter(weight_data, static_cast<int64_t>(initializer->size()), scale, zp, nullptr);
|
||||
|
||||
// Weight scale initializer.
|
||||
ONNX_NAMESPACE::TensorProto weight_scale_proto;
|
||||
weight_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_scale"));
|
||||
weight_scale_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
|
||||
weight_scale_proto.mutable_float_data()->Add(scale);
|
||||
weight_scale_arg = &graph_utils::AddInitializer(graph, weight_scale_proto);
|
||||
|
||||
// Weight zero point initializer.
|
||||
ONNX_NAMESPACE::TensorProto weight_zp_proto;
|
||||
weight_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_weight_zp"));
|
||||
weight_zp_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8);
|
||||
weight_zp_proto.mutable_int32_data()->Add(static_cast<int32_t>(zp));
|
||||
NodeArg& weight_zp_arg = graph_utils::AddInitializer(graph, weight_zp_proto);
|
||||
|
||||
// Q from float32 to int8.
|
||||
ONNX_NAMESPACE::TypeProto weight_q_type_proto;
|
||||
weight_q_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT8);
|
||||
*weight_q_type_proto.mutable_tensor_type()->mutable_shape() = *weight_arg->Shape();
|
||||
NodeArg& weight_q_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_weight_q"), &weight_q_type_proto);
|
||||
Node& weight_q_node = graph.AddNode(
|
||||
graph.GenerateNodeArgName(node.Name() + "_weight_q"), QDQ::QOpName, "Weight Q node",
|
||||
{node.MutableInputDefs()[1], weight_scale_arg, &weight_zp_arg}, {&weight_q_arg}, nullptr, node.Domain());
|
||||
|
||||
// DQ from int8 to float32.
|
||||
NodeArg& weight_dq_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_weight_dq"), weight_arg->TypeAsProto());
|
||||
Node& weight_dq_node =
|
||||
graph.AddNode(graph.GenerateNodeArgName(node.Name() + "_weight_dq"), QDQ::DQOpName, "Weight DQ node",
|
||||
{&weight_q_arg, weight_scale_arg, &weight_zp_arg}, {&weight_dq_arg}, nullptr, node.Domain());
|
||||
graph.AddEdge(weight_q_node.Index(), weight_dq_node.Index(), 0, 0);
|
||||
node.MutableInputDefs()[1] = &weight_dq_arg;
|
||||
graph.AddEdge(weight_dq_node.Index(), node.Index(), 0, 1);
|
||||
} else {
|
||||
weight_scale_arg = dq_1->MutableInputDefs()[1];
|
||||
}
|
||||
|
||||
if (bias_size != -1) {
|
||||
// Bias is quantized to int32. Q cannot support int32 as target type, need to compose the whole computation.
|
||||
// bias_scale = input_scale * weight_scale.
|
||||
NodeArg& bias_scale_arg = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_scale"),
|
||||
weight_scale_arg->TypeAsProto());
|
||||
Node& mul_node =
|
||||
graph.AddNode(graph.GenerateNodeName(node.Name() + "_scale"), "Mul", "Bias scale node",
|
||||
{dq_0.MutableInputDefs()[1], weight_scale_arg}, {&bias_scale_arg}, nullptr, node.Domain());
|
||||
|
||||
// fp_bias / scale.
|
||||
NodeArg& bias_div_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div"), bias_arg->TypeAsProto());
|
||||
Node& div_node =
|
||||
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div"), "Div", "Bias div node",
|
||||
{node.MutableInputDefs()[2], &bias_scale_arg}, {&bias_div_arg}, nullptr, node.Domain());
|
||||
graph.AddEdge(mul_node.Index(), div_node.Index(), 0, 1);
|
||||
|
||||
// Round(fp_bias / scale).
|
||||
NodeArg& bias_div_round_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_div_round"), bias_arg->TypeAsProto());
|
||||
Node& round_node =
|
||||
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_div_round"), "Round", "Bias div round node",
|
||||
{&bias_div_arg}, {&bias_div_round_arg}, nullptr, node.Domain());
|
||||
graph.AddEdge(div_node.Index(), round_node.Index(), 0, 0);
|
||||
|
||||
// Cast(Round(fp_bias / scale)) to int32.
|
||||
ONNX_NAMESPACE::TypeProto bias_int32_type_proto;
|
||||
bias_int32_type_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
*bias_int32_type_proto.mutable_tensor_type()->mutable_shape() = *bias_arg->Shape();
|
||||
NodeArg& bias_int32_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_int32"), &bias_int32_type_proto);
|
||||
Node& cast_node = graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_int32"), "Cast", "Bias INT32 node",
|
||||
{&bias_div_round_arg}, {&bias_int32_arg}, nullptr, node.Domain());
|
||||
cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_INT32));
|
||||
graph.AddEdge(round_node.Index(), cast_node.Index(), 0, 0);
|
||||
|
||||
// Bias DQ node produces output to Conv/Gemm node's input_2, with scale = input_scale_0 * input_scale_1, zp = 0.
|
||||
NodeArg& bias_dq_arg =
|
||||
graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_bias_dq"), bias_arg->TypeAsProto());
|
||||
Node& bias_dq_node =
|
||||
graph.AddNode(graph.GenerateNodeName(node.Name() + "_bias_dq"), QDQ::DQOpName, "Bias DQ node",
|
||||
{&bias_int32_arg, &bias_scale_arg}, {&bias_dq_arg}, nullptr, node.Domain());
|
||||
if (!is_per_tensor_scale) {
|
||||
bias_dq_node.AddAttribute("axis", static_cast<int64_t>(0));
|
||||
}
|
||||
|
||||
graph.AddEdge(cast_node.Index(), bias_dq_node.Index(), 0, 0);
|
||||
graph.AddEdge(mul_node.Index(), bias_dq_node.Index(), 0, 1);
|
||||
node.MutableInputDefs()[2] = &bias_dq_arg;
|
||||
graph.AddEdge(bias_dq_node.Index(), node.Index(), 0, 2);
|
||||
}
|
||||
|
||||
modified = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
* @class WeightBiasQuantization
|
||||
*
|
||||
* Some quantized models do not have Gemm/Conv/ConvTranspose's weight and/or bias quantized. This optimization adds
|
||||
* subgraphs with Q->DQ after weight and/or bias. It's possible that the ConstantFolding optimizer would fold the Q Op
|
||||
* so that weight and/or bias initializers are folded to initializers in target data types, followed by DQ Op.
|
||||
* For weight, the Q output is a symmetric per-tensor INT8 tensor.
|
||||
* For bias, the Q's scale = scale_input_0 * scale_input_1 and zero_point = (INT32)0.
|
||||
*/
|
||||
class WeightBiasQuantization : public GraphTransformer {
|
||||
public:
|
||||
WeightBiasQuantization() noexcept : GraphTransformer("WeightBiasQuantization") {}
|
||||
|
||||
private:
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -11,7 +11,7 @@
|
|||
#include "core/graph/onnx_protobuf.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/optimizer/double_qdq_pairs_remover.h"
|
||||
#include "core/optimizer/qdq_transformer/bias_quantization.h"
|
||||
#include "core/optimizer/qdq_transformer/weight_bias_quantization.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
|
||||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
|
||||
|
|
@ -4848,7 +4848,7 @@ TEST(QDQTransformerTests, DropDQSelectorWithDQProducingGraphOutput) {
|
|||
}
|
||||
#endif // !defined(DISABLE_CONTRIB_OPS)
|
||||
|
||||
TEST(QDQTransformerTests, BiasQuantization_Conv) {
|
||||
TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Bias) {
|
||||
auto test_case = [](bool use_contrib_qdq) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
NodeArg* input_arg = builder.MakeInput<uint8_t>({1, 24, 128, 128}, std::numeric_limits<uint8_t>::min(),
|
||||
|
|
@ -4896,7 +4896,93 @@ TEST(QDQTransformerTests, BiasQuantization_Conv) {
|
|||
#endif
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, BiasQuantization_Gemm) {
|
||||
TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Weight_Bias) {
|
||||
auto test_case = [](bool use_contrib_qdq) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
NodeArg* input_arg = builder.MakeInput<uint8_t>({1, 24, 67, 67}, std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max());
|
||||
NodeArg* weight_arg = builder.MakeInitializer<float>({24, 1, 5, 5}, -0.1f, 0.1f);
|
||||
NodeArg* bias_arg = builder.MakeInitializer<float>({24}, -0.1f, 0.1f);
|
||||
NodeArg* input_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* conv_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* output_arg = builder.MakeOutput();
|
||||
|
||||
builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.014f, static_cast<uint8_t>(127), input_dq_arg,
|
||||
use_contrib_qdq);
|
||||
auto& conv_node = builder.AddNode("Conv", {input_dq_arg, weight_arg, bias_arg}, {conv_dq_arg});
|
||||
conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
|
||||
conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{5, 5});
|
||||
conv_node.AddAttribute("strides", std::vector<int64_t>{2, 2});
|
||||
conv_node.AddAttribute("group", static_cast<int64_t>(24));
|
||||
conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
|
||||
builder.AddQuantizeLinearNode<uint8_t>(conv_dq_arg, 0.014f, static_cast<uint8_t>(127), output_arg,
|
||||
use_contrib_qdq);
|
||||
};
|
||||
|
||||
auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
|
||||
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
|
||||
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
|
||||
EXPECT_EQ(op_to_count["QLinearConv"], 1);
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
|
||||
};
|
||||
|
||||
test_case(false);
|
||||
#if !defined(DISABLE_CONTRIB_OPS)
|
||||
test_case(true);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, WeightBiasQuantization_ConvTranspose_Weight) {
|
||||
auto test_case = [](bool use_contrib_qdq) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
NodeArg* input_arg = builder.MakeInput<uint8_t>({1, 3, 4, 4}, std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max());
|
||||
NodeArg* weight_arg = builder.MakeInitializer<float>({3, 3, 3, 3}, -0.1f, 0.1f);
|
||||
NodeArg* input_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* conv_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* output_arg = builder.MakeOutput();
|
||||
|
||||
builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.014f, static_cast<uint8_t>(127), input_dq_arg,
|
||||
use_contrib_qdq);
|
||||
auto& conv_node = builder.AddNode("ConvTranspose", {input_dq_arg, weight_arg}, {conv_dq_arg});
|
||||
conv_node.AddAttribute("dilations", std::vector<int64_t>{1, 1});
|
||||
conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{3, 3});
|
||||
conv_node.AddAttribute("strides", std::vector<int64_t>{1, 1});
|
||||
conv_node.AddAttribute("group", static_cast<int64_t>(1));
|
||||
conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});
|
||||
builder.AddQuantizeLinearNode<uint8_t>(conv_dq_arg, 0.014f, static_cast<uint8_t>(127), output_arg,
|
||||
use_contrib_qdq);
|
||||
};
|
||||
|
||||
auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
|
||||
// No QLinearConvTranspose CPU kernel. Check the graph only.
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
|
||||
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1);
|
||||
EXPECT_EQ(op_to_count["DequantizeLinear"] + op_to_count["com.microsoft.DequantizeLinear"], 2);
|
||||
EXPECT_EQ(op_to_count["ConvTranspose"], 1);
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
|
||||
};
|
||||
|
||||
test_case(false);
|
||||
#if !defined(DISABLE_CONTRIB_OPS)
|
||||
test_case(true);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if !defined(DISABLE_CONTRIB_OPS)
|
||||
|
||||
TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Bias) {
|
||||
auto test_case = [](bool use_contrib_qdq) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
NodeArg* input_arg =
|
||||
|
|
@ -4933,10 +5019,80 @@ TEST(QDQTransformerTests, BiasQuantization_Gemm) {
|
|||
};
|
||||
|
||||
test_case(false);
|
||||
#if !defined(DISABLE_CONTRIB_OPS)
|
||||
test_case(true);
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight) {
|
||||
auto test_case = [](bool use_contrib_qdq) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
NodeArg* input_arg =
|
||||
builder.MakeInput<uint8_t>({1, 32}, std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
|
||||
NodeArg* weight_arg = builder.MakeInitializer<float>({32, 16}, -0.1f, 0.1f);
|
||||
NodeArg* input_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* gemm_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* output_arg = builder.MakeOutput();
|
||||
|
||||
builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.001f, static_cast<uint8_t>(127), input_dq_arg,
|
||||
use_contrib_qdq);
|
||||
builder.AddNode("Gemm", {input_dq_arg, weight_arg}, {gemm_dq_arg});
|
||||
builder.AddQuantizeLinearNode<uint8_t>(gemm_dq_arg, 0.144f, static_cast<uint8_t>(127), output_arg,
|
||||
use_contrib_qdq);
|
||||
};
|
||||
|
||||
auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
|
||||
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
|
||||
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
|
||||
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
|
||||
};
|
||||
|
||||
test_case(false);
|
||||
test_case(true);
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, WeightBiasQuantization_Gemm_Weight_Bias) {
|
||||
auto test_case = [](bool use_contrib_qdq) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
NodeArg* input_arg =
|
||||
builder.MakeInput<uint8_t>({1, 32}, std::numeric_limits<uint8_t>::min(), std::numeric_limits<uint8_t>::max());
|
||||
NodeArg* weight_arg = builder.MakeInitializer<float>({16, 32}, -0.1f, 0.1f);
|
||||
NodeArg* bias_arg = builder.MakeInitializer<float>({16}, -0.1f, 0.1f);
|
||||
NodeArg* input_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* gemm_dq_arg = builder.MakeIntermediate();
|
||||
NodeArg* output_arg = builder.MakeOutput();
|
||||
|
||||
builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.001f, static_cast<uint8_t>(127), input_dq_arg,
|
||||
use_contrib_qdq);
|
||||
auto& gemm_node = builder.AddNode("Gemm", {input_dq_arg, weight_arg, bias_arg}, {gemm_dq_arg});
|
||||
gemm_node.AddAttribute("transB", static_cast<int64_t>(1));
|
||||
builder.AddQuantizeLinearNode<uint8_t>(gemm_dq_arg, 0.144f, static_cast<uint8_t>(127), output_arg,
|
||||
use_contrib_qdq);
|
||||
};
|
||||
|
||||
auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
|
||||
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
|
||||
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
|
||||
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 18);
|
||||
|
||||
TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, 19);
|
||||
};
|
||||
|
||||
test_case(false);
|
||||
test_case(true);
|
||||
}
|
||||
|
||||
#endif // !defined(DISABLE_CONTRIB_OPS)
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue