Propagate QDQ only when scale and zp are scalar (#7492)

fix crash when DeQuantizeLinear's output is graph output
propagate only when scale and zp are scalar.
fix bug for is_modified= is_modified || TryCancelOutDQQPair(graph, dq_node, q_node); in which TryCancelOutDQQPair wouldn't be invoked if is_modified is true
This commit is contained in:
Yufeng Li 2021-04-29 14:40:41 -07:00 committed by GitHub
parent e255506bcd
commit d337fa90e7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 201 additions and 53 deletions

View file

@ -13,7 +13,7 @@ namespace onnxruntime {
static constexpr size_t QDQInputCountRequired = 3;
static constexpr size_t QDQInputScaleIdx = 1;
static constexpr size_t QDQInputZeroPointIdex = 2;
static constexpr size_t QDQInputZeroPointIdx = 2;
static bool CanNodePropagate(const Node& node) {
return graph_utils::IsSupportedOptypeVersionAndDomain(node, "MaxPool", {12}) ||
@ -26,22 +26,26 @@ static bool TryCancelOutDQQPair(Graph& graph, Node& dq_node, Node& q_node) {
std::vector<NodeArg*>& q_input_defs = q_node.MutableInputDefs();
if (dq_input_defs.size() != QDQInputCountRequired ||
q_input_defs.size() != QDQInputCountRequired ||
!optimizer_utils::IsScalar(*q_input_defs[QDQInputZeroPointIdx]) ||
!optimizer_utils::IsScalar(*q_input_defs[QDQInputScaleIdx]) ||
!optimizer_utils::IsScalar(*dq_input_defs[QDQInputZeroPointIdx]) ||
!optimizer_utils::IsScalar(*dq_input_defs[QDQInputScaleIdx]) ||
!graph.GetNodeOutputsInGraphOutputs(q_node).empty()) {
return false;
}
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = nullptr;
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = nullptr;
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = nullptr;
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = nullptr;
if (!graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputScaleIdx]) ||
!graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputScaleIdx]) ||
!graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputZeroPointIdex]) ||
!graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputZeroPointIdex]) ||
!graph.GetInitializedTensor(dq_input_defs[QDQInputScaleIdx]->Name(), dq_scale_tensor_proto) ||
!graph.GetInitializedTensor(q_input_defs[QDQInputScaleIdx]->Name(), q_scale_tensor_proto) ||
!graph.GetInitializedTensor(dq_input_defs[QDQInputZeroPointIdex]->Name(), dq_zp_tensor_proto) ||
!graph.GetInitializedTensor(q_input_defs[QDQInputZeroPointIdex]->Name(), q_zp_tensor_proto)) {
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputScaleIdx]->Name());
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputScaleIdx]->Name());
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputZeroPointIdx]->Name());
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputZeroPointIdx]->Name());
if (nullptr == q_zp_tensor_proto ||
nullptr == dq_zp_tensor_proto ||
nullptr == q_scale_tensor_proto ||
nullptr == dq_scale_tensor_proto) {
return false;
}
@ -162,32 +166,35 @@ bool QDQPropagationTransformer::PropagateDQForward(Graph& graph) const {
continue;
}
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = nullptr;
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = nullptr;
if (!graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputScaleIdx]) ||
!graph_utils::NodeArgIsConstant(graph, *dq_input_defs[QDQInputZeroPointIdex]) ||
!graph.GetInitializedTensor(dq_input_defs[QDQInputScaleIdx]->Name(), dq_scale_tensor_proto) ||
!graph.GetInitializedTensor(dq_input_defs[QDQInputZeroPointIdex]->Name(), dq_zp_tensor_proto)) {
if (!optimizer_utils::IsScalar(*dq_input_defs[QDQInputZeroPointIdx]) ||
!optimizer_utils::IsScalar(*dq_input_defs[QDQInputScaleIdx])) {
continue;
}
const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto =
graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputZeroPointIdx]->Name());
const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto =
graph_utils::GetConstantInitializer(graph, dq_input_defs[QDQInputScaleIdx]->Name());
if (nullptr == dq_zp_tensor_proto || nullptr == dq_scale_tensor_proto) {
continue;
}
do {
Node& next_node = *graph.GetNode(dq_node.OutputNodesBegin()->Index());
if (!CanNodePropagate(next_node)) {
// Try canceling out DQ/Q pair
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "QuantizeLinear", {10, 13}) &&
graph_utils::IsSupportedProvider(next_node, GetCompatibleExecutionProviders()) &&
TryCancelOutDQQPair(graph, dq_node, next_node)) {
is_modified = true;
}
break;
}
SwapAdjacentNodes(graph, dq_node, next_node);
is_modified = true;
} while (optimizer_utils::CheckOutputEdges(graph, dq_node, 1));
// Cancel out DQ/Q pair
Node& q_node = *graph.GetNode(dq_node.OutputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(q_node, "QuantizeLinear", {10, 13}) ||
!graph_utils::IsSupportedProvider(q_node, GetCompatibleExecutionProviders())) {
continue;
}
is_modified = is_modified || TryCancelOutDQQPair(graph, dq_node, q_node);
}
return is_modified;
@ -212,16 +219,18 @@ bool QDQPropagationTransformer::PropagateQBackward(Graph& graph) const {
}
std::vector<NodeArg*>& q_input_defs = q_node.MutableInputDefs();
if (q_input_defs.size() != QDQInputCountRequired) {
if (q_input_defs.size() != QDQInputCountRequired ||
!optimizer_utils::IsScalar(*q_input_defs[QDQInputZeroPointIdx]) ||
!optimizer_utils::IsScalar(*q_input_defs[QDQInputScaleIdx])) {
continue;
}
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = nullptr;
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = nullptr;
if (!graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputScaleIdx]) ||
!graph_utils::NodeArgIsConstant(graph, *q_input_defs[QDQInputZeroPointIdex]) ||
!graph.GetInitializedTensor(q_input_defs[QDQInputScaleIdx]->Name(), q_scale_tensor_proto) ||
!graph.GetInitializedTensor(q_input_defs[QDQInputZeroPointIdex]->Name(), q_zp_tensor_proto)) {
const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto =
graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputZeroPointIdx]->Name());
const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto =
graph_utils::GetConstantInitializer(graph, q_input_defs[QDQInputScaleIdx]->Name());
if (nullptr == q_zp_tensor_proto || nullptr == q_scale_tensor_proto) {
continue;
}
@ -230,27 +239,21 @@ bool QDQPropagationTransformer::PropagateQBackward(Graph& graph) const {
break;
}
Node& prev_node = *graph.GetNode(q_node.InputNodesBegin()->Index());
if (!optimizer_utils::CheckOutputEdges(graph, prev_node, 1) ||
!CanNodePropagate(prev_node)) {
if (!optimizer_utils::CheckOutputEdges(graph, prev_node, 1)) break;
if (!CanNodePropagate(prev_node)) {
// Try canceling out DQ/Q pair
Node& dq_node = prev_node;
if (graph_utils::IsSupportedOptypeVersionAndDomain(dq_node, "DequantizeLinear", {10, 13}) &&
graph_utils::IsSupportedProvider(dq_node, GetCompatibleExecutionProviders()) &&
TryCancelOutDQQPair(graph, dq_node, q_node)) {
is_modified = true;
}
break;
}
SwapAdjacentNodes(graph, prev_node, q_node);
is_modified = true;
} while (true);
// Cancel out DQ/Q pair
if (q_node.InputNodesBegin() == q_node.InputNodesEnd()) {
continue;
}
Node& dq_node = *graph.GetNode(q_node.InputNodesBegin()->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(dq_node, "DequantizeLinear", {10, 13}) ||
!graph_utils::IsSupportedProvider(dq_node, GetCompatibleExecutionProviders()) ||
!optimizer_utils::CheckOutputEdges(graph, dq_node, 1)) {
continue;
}
is_modified = is_modified || TryCancelOutDQQPair(graph, dq_node, q_node);
}
return is_modified;

View file

@ -21,10 +21,12 @@ namespace onnxruntime {
namespace optimizer_utils {
bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
return tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
return tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ||
tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
}
inline bool IsScalar(const NodeArg& input_arg) {
bool IsScalar(const NodeArg& input_arg) {
auto shape = input_arg.Shape();
if (shape == nullptr) {
// shape inferencing wasn't able to populate shape information for this NodeArg

View file

@ -15,6 +15,9 @@ namespace optimizer_utils {
// Check if TensorProto contains a floating point type.
bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto);
// Check if NodeArg takes in a scalar tensor.
bool IsScalar(const NodeArg& input_arg);
/** Check whether a input is initializer with specified float value.
@param expected_value is the expected value of the initializer.
@param is_constant means whether the initializer is required to be constant.

View file

@ -881,8 +881,148 @@ TEST(QDQTransformerTests, QDQPropagation_QDQ_CancelOut_More) {
test_case({1, 13, 13, 23}, true, true);
}
TEST(QDQTransformerTests, QDQPropagation_Q_No_Parent) {
auto test_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& perms) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<float>(input_shape, -1.f, 1.f);
auto* output_arg = builder.MakeOutput();
// add transpose
auto* transpose_output = builder.MakeIntermediate();
Node& transpose_node = builder.AddNode("Transpose", {input_arg}, {transpose_output});
transpose_node.AddAttribute("perm", perms);
// add Q
builder.AddQuantizeLinearNode<uint8_t>(transpose_output, .0035f, 135, output_arg);
};
auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) {
GraphViewer graph_viewer(session.GetGraph());
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[0])->OpType(), "QuantizeLinear");
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[1])->OpType(), "Transpose");
};
TransformerTester(build_test_case,
check_mp_reshape_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
12 /*opset_version*/);
};
// Test the basic case of a single 1D/2D/3D convolution.
test_case({1, 13, 13, 23}, {0, 2, 3, 1});
}
TEST(QDQTransformerTests, QDQPropagation_DQ_No_Children) {
auto test_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& perms) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<uint8_t>(input_shape,
std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
auto* output_arg = builder.MakeOutput();
// add DQ
auto* dq_output = builder.MakeIntermediate();
builder.AddDequantizeLinearNode<uint8_t>(input_arg, .0035f, 135, dq_output);
// add transpose
Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {output_arg});
transpose_node.AddAttribute("perm", perms);
};
auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
GraphViewer graph_viewer(session.GetGraph());
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[0])->OpType(), "Transpose");
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[1])->OpType(), "DequantizeLinear");
};
TransformerTester(build_test_case,
check_mp_reshape_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
12 /*opset_version*/);
};
// Test the basic case of a single 1D/2D/3D convolution.
test_case({1, 13, 13, 23}, {0, 2, 3, 1});
}
TEST(QDQTransformerTests, QDQPropagation_Per_Layer_No_Propagation) {
auto test_case = [&](const std::vector<int64_t>& input_shape, const std::vector<int64_t>& perms) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<uint8_t>(input_shape,
std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
auto* output_arg = builder.MakeOutput();
// add DQ
auto* dq_output = builder.MakeIntermediate();
auto* dq_scale = builder.Make1DInitializer(std::vector<float>(input_shape[1], 0.0035f));
auto* dq_zp = builder.Make1DInitializer(std::vector<uint8_t>(input_shape[1], 135));
builder.AddNode("DequantizeLinear", {input_arg, dq_scale, dq_zp}, {dq_output});
// add transpose
Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {output_arg});
transpose_node.AddAttribute("perm", perms);
};
auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
GraphViewer graph_viewer(session.GetGraph());
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[0])->OpType(), "DequantizeLinear");
EXPECT_EQ(graph_viewer.GetNode(node_topology_list[1])->OpType(), "Transpose");
};
TransformerTester(build_test_case,
check_mp_reshape_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
12 /*opset_version*/);
};
// Test the basic case of a single 1D/2D/3D convolution.
test_case({1, 13, 13, 23}, {0, 2, 3, 1});
}
TEST(QDQTransformerTests, QDQPropagation_DQ_Q) {
auto test_case = [&](const std::vector<int64_t>& input_shape) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<uint8_t>(input_shape,
std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
auto* output_arg = builder.MakeOutput();
// add DQ
auto* dq_output = builder.MakeIntermediate();
builder.AddDequantizeLinearNode<uint8_t>(input_arg, .0035f, 135, dq_output);
// add Q
builder.AddQuantizeLinearNode<uint8_t>(dq_output, .0035f, 135, output_arg);
};
auto check_mp_reshape_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
EXPECT_EQ(op_to_count["QuantizeLinear"], 1);
EXPECT_EQ(op_to_count["DequantizeLinear"], 1);
};
TransformerTester(build_test_case,
check_mp_reshape_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
12 /*opset_version*/);
};
// Test the basic case of a single 1D/2D/3D convolution.
test_case({1, 13, 13, 23});
}
TEST(QDQTransformerTests, Concat_UInt8) {
auto test_case = [&](const std::vector<std::vector<int64_t>>& input_shapes, int64_t axis, bool can_trans=true) {
auto test_case = [&](const std::vector<std::vector<int64_t>>& input_shapes, int64_t axis, bool can_trans = true) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto input_count = input_shapes.size();
std::vector<NodeArg*> input_args;