mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Propagate QDQ only when scale and zp are scalar (#7492)
fix crash when DeQuantizeLinear's output is graph output propagate only when scale and zp are scalar. fix bug for is_modified= is_modified || TryCancelOutDQQPair(graph, dq_node, q_node); in which TryCancelOutDQQPair wouldn't be invoked if is_modified is true
This commit is contained in:
parent
e255506bcd
commit
d337fa90e7
4 changed files with 201 additions and 53 deletions
|
|
@ -13,7 +13,7 @@ namespace onnxruntime {
|
|||
|
||||
static constexpr size_t QDQInputCountRequired = 3;
|
||||
static constexpr size_t QDQInputScaleIdx = 1;
|
||||
static constexpr size_t QDQInputZeroPointIdex = 2;
|
||||
static constexpr size_t QDQInputZeroPointIdx = 2;
|
||||
|
||||
static bool CanNodePropagate(const Node& node) {
|
||||
return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {12}) ||
|
||||
|
|
@ -26,22 +26,26 @@ static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) {
|
|||
std::vector<NodeArg*>& q_input_defs = q_node.MutableInputDefs();
|
||||
if (dq_input_defs.size() != QDQInputCountRequired ||
|
||||
q_input_defs.size() != QDQInputCountRequired ||
|
||||
!optimizer_utils::IsScalar(*q_input_defs[QDQInputZeroPointIdx]) ||
|
||||
!optimizer_utils::IsScalar(*q_input_defs[QDQInputScaleIdx]) ||
|
||||
!optimizer_utils::IsScalar(*dq_input_defs[QDQInputZeroPointIdx]) ||
|
||||
!optimizer_utils::IsScalar(*dq_input_defs[QDQInputScaleIdx]) ||
|
||||
!graph.GetNodeOutputsInGraphOutputs(q_node).empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = nullptr;
|
||||
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = nullptr;
|
||||
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = nullptr;
|
||||
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = nullptr;
|
||||
if (!graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputScaleIdx]) ||
|
||||
!graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputScaleIdx]) ||
|
||||
!graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputZeroPointIdex]) ||
|
||||
!graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputZeroPointIdex]) ||
|
||||
!graph.GetInitializedTensor(dq_input_defs[QDQInputScaleIdx]->Name(), dq_scale_tensor_proto) ||
|
||||
!graph.GetInitializedTensor(q_input_defs[QDQInputScaleIdx]->Name(), q_scale_tensor_proto) ||
|
||||
!graph.GetInitializedTensor(dq_input_defs[QDQInputZeroPointIdex]->Name(), dq_zp_tensor_proto) ||
|
||||
!graph.GetInitializedTensor(q_input_defs[QDQInputZeroPointIdex]->Name(), q_zp_tensor_proto)) {
|
||||
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputScaleIdx]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputScaleIdx]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputZeroPointIdx]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputZeroPointIdx]->Name());
|
||||
if (nullptr == q_zp_tensor_proto ||
|
||||
nullptr == dq_zp_tensor_proto ||
|
||||
nullptr == q_scale_tensor_proto ||
|
||||
nullptr == dq_scale_tensor_proto) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -162,32 +166,35 @@ bool QDQPropagationTransformer::PropagateDQForward(Graph& graph) const {
|
|||
continue;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = nullptr;
|
||||
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = nullptr;
|
||||
if (!graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputScaleIdx]) ||
|
||||
!graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputZeroPointIdex]) ||
|
||||
!graph.GetInitializedTensor(dq_input_defs[QDQInputScaleIdx]->Name(), dq_scale_tensor_proto) ||
|
||||
!graph.GetInitializedTensor(dq_input_defs[QDQInputZeroPointIdex]->Name(), dq_zp_tensor_proto)) {
|
||||
if (!optimizer_utils::IsScalar(*dq_input_defs[QDQInputZeroPointIdx]) ||
|
||||
!optimizer_utils::IsScalar(*dq_input_defs[QDQInputScaleIdx])) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputZeroPointIdx]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputScaleIdx]->Name());
|
||||
|
||||
if (nullptr == dq_zp_tensor_proto || nullptr == dq_scale_tensor_proto) {
|
||||
continue;
|
||||
}
|
||||
|
||||
do {
|
||||
Node& next_node = *graph.GetNode(dq_node.OutputNodesBegin()->Index());
|
||||
if (!CanNodePropagate(next_node)) {
|
||||
// Try canceling out DQ/Q pair
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "QuantizeLinear", {10, 13}) &&
|
||||
graph_utils::IsSupportedProvider(next_node, GetCompatibleExecutionProviders()) &&
|
||||
TryCancelOutDQQPair(graph, dq_node, next_node)) {
|
||||
is_modified = true;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
SwapAdjacentNodes(graph, dq_node, next_node);
|
||||
is_modified = true;
|
||||
} while (optimizer_utils::CheckOutputEdges(graph, dq_node, 1));
|
||||
|
||||
// Cancel out DQ/Q pair
|
||||
Node& q_node = *graph.GetNode(dq_node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(q_node, "QuantizeLinear", {10, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(q_node, GetCompatibleExecutionProviders())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
is_modified = is_modified || TryCancelOutDQQPair(graph, dq_node, q_node);
|
||||
}
|
||||
|
||||
return is_modified;
|
||||
|
|
@ -212,16 +219,18 @@ bool QDQPropagationTransformer::PropagateQBackward(Graph& graph) const {
|
|||
}
|
||||
|
||||
std::vector<NodeArg*>& q_input_defs = q_node.MutableInputDefs();
|
||||
if (q_input_defs.size() != QDQInputCountRequired) {
|
||||
if (q_input_defs.size() != QDQInputCountRequired ||
|
||||
!optimizer_utils::IsScalar(*q_input_defs[QDQInputZeroPointIdx]) ||
|
||||
!optimizer_utils::IsScalar(*q_input_defs[QDQInputScaleIdx])) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = nullptr;
|
||||
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = nullptr;
|
||||
if (!graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputScaleIdx]) ||
|
||||
!graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputZeroPointIdex]) ||
|
||||
!graph.GetInitializedTensor(q_input_defs[QDQInputScaleIdx]->Name(), q_scale_tensor_proto) ||
|
||||
!graph.GetInitializedTensor(q_input_defs[QDQInputZeroPointIdex]->Name(), q_zp_tensor_proto)) {
|
||||
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputZeroPointIdx]->Name());
|
||||
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputScaleIdx]->Name());
|
||||
|
||||
if (nullptr == q_zp_tensor_proto || nullptr == q_scale_tensor_proto) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -230,27 +239,21 @@ bool QDQPropagationTransformer::PropagateQBackward(Graph& graph) const {
|
|||
break;
|
||||
}
|
||||
Node& prev_node = *graph.GetNode(q_node.InputNodesBegin()->Index());
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, prev_node, 1) ||
|
||||
!CanNodePropagate(prev_node)) {
|
||||
if (!optimizer_utils::CheckOutputEdges(graph, prev_node, 1)) break;
|
||||
if (!CanNodePropagate(prev_node)) {
|
||||
// Try canceling out DQ/Q pair
|
||||
Node& dq_node = prev_node;
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(dq_node, "DequantizeLinear", {10, 13}) &&
|
||||
graph_utils::IsSupportedProvider(dq_node, GetCompatibleExecutionProviders()) &&
|
||||
TryCancelOutDQQPair(graph, dq_node, q_node)) {
|
||||
is_modified = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
SwapAdjacentNodes(graph, prev_node, q_node);
|
||||
is_modified = true;
|
||||
} while (true);
|
||||
|
||||
// Cancel out DQ/Q pair
|
||||
if (q_node.InputNodesBegin() == q_node.InputNodesEnd()) {
|
||||
continue;
|
||||
}
|
||||
Node& dq_node = *graph.GetNode(q_node.InputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(dq_node, "DequantizeLinear", {10, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(dq_node, GetCompatibleExecutionProviders()) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, dq_node, 1)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
is_modified = is_modified || TryCancelOutDQQPair(graph, dq_node, q_node);
|
||||
}
|
||||
|
||||
return is_modified;
|
||||
|
|
|
|||
|
|
@ -21,10 +21,12 @@ namespace onnxruntime {
|
|||
namespace optimizer_utils {
|
||||
|
||||
bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
|
||||
return tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
|
||||
return tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
|
||||
tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ||
|
||||
tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
|
||||
}
|
||||
|
||||
inline bool IsScalar(const NodeArg& input_arg) {
|
||||
bool IsScalar(const NodeArg& input_arg) {
|
||||
auto shape = input_arg.Shape();
|
||||
if (shape == nullptr) {
|
||||
// shape inferencing wasn't able to populate shape information for this NodeArg
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@ namespace optimizer_utils {
|
|||
// Check if TensorProto contains a floating point type.
|
||||
bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto);
|
||||
|
||||
// Check if NodeArg takes in a scalar tensor.
|
||||
bool IsScalar(const NodeArg& input_arg);
|
||||
|
||||
/** Check whether a input is initializer with specified float value.
|
||||
@param expected_value is the expected value of the initializer.
|
||||
@param is_constant means whether the initializer is required to be constant.
|
||||
|
|
|
|||
|
|
@ -881,8 +881,148 @@ TEST(QDQTransformerTests, QDQPropagation_QDQ_CancelOut_More) {
|
|||
test_case({1, 13, 13, 23}, true, true);
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, QDQPropagation_Q_No_Parent) {
|
||||
auto test_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& perms) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<float>(input_shape, -1.f, 1.f);
|
||||
auto* output_arg = builder.MakeOutput();
|
||||
|
||||
// add transpose
|
||||
auto* transpose_output = builder.MakeIntermediate();
|
||||
Node& transpose_node = builder.AddNode("Transpose", {input_arg}, {transpose_output});
|
||||
transpose_node.AddAttribute("perm", perms);
|
||||
|
||||
// add Q
|
||||
builder.AddQuantizeLinearNode<uint8_t>(transpose_output, .0035f, 135, output_arg);
|
||||
};
|
||||
|
||||
auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) {
|
||||
GraphViewer graph_viewer(session.GetGraph());
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[0])->OpType(), "QuantizeLinear");
|
||||
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[1])->OpType(), "Transpose");
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case,
|
||||
check_mp_reshape_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
12 /*opset_version*/);
|
||||
};
|
||||
|
||||
// Test the basic case of a single 1D/2D/3D convolution.
|
||||
test_case({1, 13, 13, 23}, {0, 2, 3, 1});
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, QDQPropagation_DQ_No_Children) {
|
||||
auto test_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& perms) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<uint8_t>(input_shape,
|
||||
std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max());
|
||||
auto* output_arg = builder.MakeOutput();
|
||||
|
||||
// add DQ
|
||||
auto* dq_output = builder.MakeIntermediate();
|
||||
builder.AddDequantizeLinearNode<uint8_t>(input_arg, .0035f, 135, dq_output);
|
||||
|
||||
// add transpose
|
||||
Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {output_arg});
|
||||
transpose_node.AddAttribute("perm", perms);
|
||||
};
|
||||
|
||||
auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
GraphViewer graph_viewer(session.GetGraph());
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[0])->OpType(), "Transpose");
|
||||
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[1])->OpType(), "DequantizeLinear");
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case,
|
||||
check_mp_reshape_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
12 /*opset_version*/);
|
||||
};
|
||||
|
||||
// Test the basic case of a single 1D/2D/3D convolution.
|
||||
test_case({1, 13, 13, 23}, {0, 2, 3, 1});
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, QDQPropagation_Per_Layer_No_Propagation) {
|
||||
auto test_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& perms) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<uint8_t>(input_shape,
|
||||
std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max());
|
||||
auto* output_arg = builder.MakeOutput();
|
||||
|
||||
// add DQ
|
||||
auto* dq_output = builder.MakeIntermediate();
|
||||
auto* dq_scale = builder.Make1DInitializer(std::vector<float>(input_shape[1], 0.0035f));
|
||||
auto* dq_zp = builder.Make1DInitializer(std::vector<uint8_t>(input_shape[1], 135));
|
||||
builder.AddNode("DequantizeLinear", {input_arg, dq_scale, dq_zp}, {dq_output});
|
||||
|
||||
// add transpose
|
||||
Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {output_arg});
|
||||
transpose_node.AddAttribute("perm", perms);
|
||||
};
|
||||
|
||||
auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
GraphViewer graph_viewer(session.GetGraph());
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[0])->OpType(), "DequantizeLinear");
|
||||
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[1])->OpType(), "Transpose");
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case,
|
||||
check_mp_reshape_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
12 /*opset_version*/);
|
||||
};
|
||||
|
||||
// Test the basic case of a single 1D/2D/3D convolution.
|
||||
test_case({1, 13, 13, 23}, {0, 2, 3, 1});
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, QDQPropagation_DQ_Q) {
|
||||
auto test_case = [&](const std::vector<int64_t>& input_shape) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<uint8_t>(input_shape,
|
||||
std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max());
|
||||
auto* output_arg = builder.MakeOutput();
|
||||
|
||||
// add DQ
|
||||
auto* dq_output = builder.MakeIntermediate();
|
||||
builder.AddDequantizeLinearNode<uint8_t>(input_arg, .0035f, 135, dq_output);
|
||||
|
||||
// add Q
|
||||
builder.AddQuantizeLinearNode<uint8_t>(dq_output, .0035f, 135, output_arg);
|
||||
};
|
||||
|
||||
auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) {
|
||||
auto op_to_count = CountOpsInGraph(session.GetGraph());
|
||||
EXPECT_EQ(op_to_count["QuantizeLinear"], 1);
|
||||
EXPECT_EQ(op_to_count["DequantizeLinear"], 1);
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case,
|
||||
check_mp_reshape_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
12 /*opset_version*/);
|
||||
};
|
||||
|
||||
// Test the basic case of a single 1D/2D/3D convolution.
|
||||
test_case({1, 13, 13, 23});
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, Concat_UInt8) {
|
||||
auto test_case = [&](const std::vector<std::vector<int64_t>>& input_shapes, int64_t axis, bool can_trans=true) {
|
||||
auto test_case = [&](const std::vector<std::vector<int64_t>>& input_shapes, int64_t axis, bool can_trans = true) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto input_count = input_shapes.size();
|
||||
std::vector<NodeArg*> input_args;
|
||||
|
|
|
|||
Loading…
Reference in a new issue