[TransposeOptimizer] Fix axis for QuantizeLinear inserted after DQ (per-channel) -> Unsqueeze (#21793)

### Description
- Fix computation of axis for `QuantizeLinear` inserted after the
sequence `DQ (per-channel) -> Unsqueeze`. Example:
  - Original: `DQ (axis = 0) -> Unsqueeze (axes = [0, 1, 2]) -> Op`
- After QDQ fix-up: `DQ (axis = 0) -> Unsqueeze (axes = [0, 1, 2]) -> Q
(axis = 3) -> DQ (axis = 3) -> Op`
- Before this PR, the axis for the inserted Q/DQ ops was not correctly
set to 3 (left as 0).
- Fix normalization of negative axis values for `QuantizeLinear`
inserted after the sequence `DQ (per-channel) -> Transpose`
  - Existing code added the wrong rank value to normalize the DQ axis.

### Motivation and Context
Fix errors in handling of per-channel DQ in code that fixes QDQ
NodeUnits.
This commit is contained in:
Adrian Lizarraga 2024-08-20 16:26:02 -07:00 committed by GitHub
parent 28c252c77e
commit 6fbb0ae81a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 360 additions and 31 deletions

View file

@ -24,6 +24,12 @@ static constexpr bool IsOnnxDomain(std::string_view domain) {
return (domain == onnxruntime::kOnnxDomain) || (domain == onnxruntime::kOnnxDomainAlias);
}
// Checks whether a tensor shape describes a single value: either a scalar (empty shape, rank 0)
// or a 1D tensor whose only dimension is 1, i.e. shape (1,). Any other shape yields false.
constexpr bool IsScalarOr1Element1DTensor(gsl::span<const int64_t> tensor_shape) {
  switch (tensor_shape.size()) {
    case 0:
      return true;  // scalar
    case 1:
      return tensor_shape[0] == 1;  // shape (1,)
    default:
      return false;  // rank >= 2
  }
}
static std::vector<int64_t> DataInt64(api::TensorRef& tensor) {
std::vector<uint8_t> raw_data = tensor.Data();
int64_t* data_int = reinterpret_cast<int64_t*>(raw_data.data());
@ -286,6 +292,13 @@ static std::unique_ptr<api::NodeRef> GetDQWithConstInitializerInputAndSingleCons
return result;
}
// Forward declarations for utils used by MakeQDQNodeUnit
// Adds `rank` to negative entries of `axes` in place; returns false when an entry is out of range or repeated.
static bool NormalizeAndValidateAxes(std::vector<int64_t>& axes, size_t rank);
// Reads int64 values either from attribute `attr_name` or from the constant at input `inp_index`, chosen by
// comparing the graph's opset for the node's domain against `opset`.
static std::optional<std::vector<int64_t>> ReadFromAttrOrInput(const api::GraphRef& graph, api::NodeRef& node,
std::string_view attr_name, size_t inp_index,
int64_t opset);
// Shifts `axis` to account for the dims inserted by an Unsqueeze; the unsqueeze axes must be sorted in
// increasing order, non-negative, and free of duplicates.
static int64_t UnsqueezeAxis(gsl::span<const int64_t> sorted_positive_unsqueeze_axes, int64_t axis);
/// <summary>
/// Insert a Q -> DQ pair after the node following the DQ by using scale and zp info from the preceding DQ node.
/// DQ -> next node => DQ -> next node -> Q -> DQ.
@ -308,6 +321,7 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) {
const auto dq_domain = dq_node.Domain();
const auto& dq_inputs = dq_node.Inputs();
const bool is_transpose = next_node.OpType() == "Transpose";
const bool is_unsqueeze = next_node.OpType() == "Unsqueeze";
const auto scale_input = dq_inputs[1];
const auto scale_value_info = graph.GetValueInfo(scale_input);
@ -315,8 +329,8 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) {
std::optional<std::unique_ptr<api::ValueInfoRef>> zp_value_info;
auto scale_shape = scale_value_info->Shape();
if (!scale_shape && is_transpose) {
// axis potentially needs updating due to the transpose but we don't have the required info to do it.
if (!scale_shape) {
// axis potentially needs updating due to the transpose or unsqueeze but we don't have the required info to do it.
return false;
}
@ -325,19 +339,38 @@ static bool MakeQDQNodeUnit(api::GraphRef& graph, const api::NodeRef& dq_node) {
zp_value_info = graph.GetValueInfo(zp_input.value());
}
// per-axis quantization if not a scalar (shape is empty for scalar).
// note there could be an axis value as the onnx spec says that is ignored for per-tensor quantization,
// so we have to check the shape.
auto update_dq_axis = scale_shape && !scale_shape->empty();
// DQ uses per-axis quantization if its scale input is not a scalar and not a tensor with shape (1,).
// Note there could be an axis value as the onnx spec says that is ignored for per-tensor quantization,
// so we have to check the scale input's shape.
const bool update_dq_axis = !IsScalarOr1Element1DTensor(*scale_shape);
int64_t axis = dq_node.GetAttributeIntDefault("axis", 1);
// TODO(adrianlizarraga): Also need to update axis if Unsqueeze inserts a 1 before the axis dim.
if (update_dq_axis && is_transpose) {
// update axis.
auto perm = GetPermAttrIfValid(next_node);
assert(perm.has_value()); // onnx shape inferencing checks that `perm` is valid
NormalizeAndValidateAxis(axis, scale_shape->size());
axis = InvertPerm(*perm)[gsl::narrow_cast<size_t>(axis)];
if (update_dq_axis) {
const auto dq_input0_info = graph.GetValueInfo(dq_inputs[0]);
auto dq_input0_rank = dq_input0_info->ShapeRank();
if (!dq_input0_rank.has_value() || !NormalizeAndValidateAxis(axis, *dq_input0_rank)) {
return false; // Unable to normalize the DQ's axis.
}
if (is_transpose) {
auto perm = GetPermAttrIfValid(next_node);
assert(perm.has_value()); // onnx shape inferencing checks that `perm` is valid
axis = InvertPerm(*perm)[gsl::narrow_cast<size_t>(axis)];
} else if (is_unsqueeze) {
auto axes = ReadFromAttrOrInput(graph, next_node, "axes", /*inp_index*/ 1, /*opset*/ 13);
assert(axes.has_value()); // 'axes' are required for Unsqueeze
// Normalize negative unsqueeze axes by adding output rank.
// Unsqueeze output rank = input_rank + axes.size()
// Unsqueeze's input rank is the same as the DQ's input[0] rank.
if (!NormalizeAndValidateAxes(*axes, *dq_input0_rank + axes->size())) {
return false;
}
// Need to update axis if Unsqueeze inserts a 1 before the axis dim.
std::sort(axes->begin(), axes->end());
axis = UnsqueezeAxis(*axes, axis);
}
}
auto next_node_output_name = next_node.Outputs()[0];
@ -469,32 +502,67 @@ static bool NormalizeAndValidateAxes(std::vector<int64_t>& axes, size_t rank) {
for (size_t i = 0; i < axes.size(); ++i) {
if (axes[i] < 0) {
axes[i] += rank_int;
size_t x_size_t = gsl::narrow_cast<size_t>(axes[i]);
if (axes[i] < 0 || axes[i] >= rank_int || used_dims[x_size_t]) {
return false;
}
used_dims[x_size_t] = true;
}
size_t x_size_t = gsl::narrow_cast<size_t>(axes[i]);
if (axes[i] < 0 || axes[i] >= rank_int || used_dims[x_size_t]) {
return false;
}
used_dims[x_size_t] = true;
}
return true;
}
// Fetches the int64 contents of the constant that feeds input `inp_index` of `node`.
// Yields std::nullopt when the node has no such input slot, when the input's name is empty,
// or when the graph returns no constant for that input name.
static std::optional<std::vector<int64_t>> ReadInt64sFromInput(const api::GraphRef& graph, api::NodeRef& node,
                                                               size_t inp_index) {
  auto node_inputs = node.Inputs();

  // The bounds check must come first so the name lookup is never out of range.
  const bool has_named_input = inp_index < node_inputs.size() && !(node_inputs[inp_index] == "");
  if (!has_named_input) {
    return std::nullopt;
  }

  auto const_tensor = graph.GetConstant(node_inputs[inp_index]);
  if (const_tensor != nullptr) {
    return DataInt64(*const_tensor);
  }

  return std::nullopt;
}
// Read int64 data from attribute or input, depending on whether model opset < provided opset.
// Assumes that node is in the default ONNX domain.
//
// ctx:       optimizer context; supplies the model opset (`ctx.opset`) and the graph (`ctx.graph`).
// node:      node to read from.
// attr_name: attribute holding the values when `ctx.opset < opset`.
// inp_index: index of the input holding the values otherwise.
// opset:     first opset in which the values are provided as an input instead of an attribute.
//
// Returns the values, or std::nullopt if they are not available (see ReadInt64sFromInput for the input case).
static std::optional<std::vector<int64_t>> ReadFromAttrOrInput(OptimizerCtx& ctx, api::NodeRef& node,
                                                               std::string_view attr_name, size_t inp_index,
                                                               int64_t opset) {
  assert(IsOnnxDomain(node.Domain()));  // ctx.opset is only for Onnx domain.

  if (ctx.opset < opset) {
    return node.GetAttributeInts(attr_name);
  }

  // The input-reading logic lives in ReadInt64sFromInput so both overloads share it.
  return ReadInt64sFromInput(ctx.graph, node, inp_index);
}
// Read int64 data from attribute or input, depending on whether model opset < provided opset.
// Unlike the OptimizerCtx overload, the opset is looked up from the graph using the node's domain.
//
// graph:            graph that owns `node`; queried for the opset and for constant inputs.
// node:             node to read from.
// attr_name:        attribute holding the values when the actual opset < `opset_with_input`.
// input_index:      index of the input holding the values otherwise.
// opset_with_input: first opset in which the values are provided as an input instead of an attribute.
//
// Returns the values, or std::nullopt if the graph reports no opset for the node's domain or the
// values are not available.
static std::optional<std::vector<int64_t>> ReadFromAttrOrInput(const api::GraphRef& graph, api::NodeRef& node,
                                                               std::string_view attr_name, size_t input_index,
                                                               int64_t opset_with_input) {
  std::optional<int64_t> actual_opset;

  if (IsOnnxDomain(node.Domain())) {
    // The ONNX domain can be registered under either of two names; try the primary one first.
    actual_opset = graph.Opset(onnxruntime::kOnnxDomain);
    if (!actual_opset.has_value()) {
      actual_opset = graph.Opset(onnxruntime::kOnnxDomainAlias);
    }
  } else {
    actual_opset = graph.Opset(node.Domain());
  }

  if (!actual_opset.has_value()) {
    return std::nullopt;
  }

  if (*actual_opset < opset_with_input) {
    return node.GetAttributeInts(attr_name);
  }

  return ReadInt64sFromInput(graph, node, input_index);
}
@ -685,6 +753,24 @@ static std::vector<int64_t> SqueezePerm(const std::vector<int64_t>& axes, const
return new_perm;
}
// Maps an axis of a tensor to the matching axis of that tensor after it has been unsqueezed.
// Each unsqueeze axis that is less than or equal to the running result pushes the result right by one.
// The result is incorrect if any unsqueeze axes are negative, duplicated, or not sorted in increasing order.
//
// Ex: axes = [0, 1, 2], axis = 0, new_axis = 3
//     axes = [0, 1, 3], axis = 1, new_axis = 4
static int64_t UnsqueezeAxis(gsl::span<const int64_t> sorted_positive_unsqueeze_axes, int64_t axis) {
  assert(axis >= 0);

  int64_t shifted_axis = axis;
  for (auto it = sorted_positive_unsqueeze_axes.begin(); it != sorted_positive_unsqueeze_axes.end(); ++it) {
    // Compare against the already-shifted value: earlier insertions move the tracked dim to the right.
    shifted_axis += (*it <= shifted_axis) ? 1 : 0;
  }

  return shifted_axis;
}
// Computes a new axes attribute for an input that has been permuted using perm. Unsafe if axes/perm are invalid or
// have negative values.
//
@ -2662,16 +2748,22 @@ static bool TryFixTransposeMissingDQ(OptimizerCtx& ctx, api::NodeRef& transpose_
zp_value_info = ctx.graph.GetValueInfo(zp_input.value());
}
// Per-axis quantization if not a scalar (shape is empty for scalar).
// note there could be an axis value as the onnx spec says that is ignored for per-tensor quantization,
// so we have to check the shape.
const bool update_axis = scale_shape && !scale_shape->empty();
// Q uses per-axis quantization if its scale input is not a scalar and not a tensor with shape (1,).
// Note there could be an axis value as the onnx spec says that is ignored for per-tensor quantization,
// so we have to check the scale input's shape.
const bool update_axis = !IsScalarOr1Element1DTensor(*scale_shape);
int64_t axis = q_node.GetAttributeIntDefault("axis", 1);
if (update_axis) {
auto perm = GetPermAttrIfValid(transpose_node);
assert(perm.has_value()); // onnx shape inferencing checks that `perm` is valid
NormalizeAndValidateAxis(axis, scale_shape->size());
const auto q_input0_info = ctx.graph.GetValueInfo(q_node_inputs[0]);
std::optional<size_t> q_input0_rank = q_input0_info->ShapeRank();
if (!q_input0_rank.has_value() || !NormalizeAndValidateAxis(axis, *q_input0_rank)) {
return false; // Unable to normalize the Q's axis.
}
axis = (*perm)[gsl::narrow_cast<size_t>(axis)]; // Note: do not invert permutation.
}

View file

@ -105,6 +105,12 @@ class ValueInfoRef {
/// </returns>
virtual std::optional<std::vector<int64_t>> Shape() const = 0;
/// <returns>
/// The inferred/declared rank of the value's tensor shape, or nullopt if the rank is unknown. A scalar
/// has a rank of 0.
/// </returns>
virtual std::optional<size_t> ShapeRank() const = 0;
/// <returns>The inferred/declared dtype of the value. UNDEFINED (0) if dtype is unknown.</returns>
virtual DataType DType() const = 0;

View file

@ -33,6 +33,7 @@ class ApiValueInfo final : public api::ValueInfoRef {
explicit ApiValueInfo(NodeArg& node_arg) : node_arg_(node_arg) {}
std::string_view Name() const override;
std::optional<std::vector<int64_t>> Shape() const override;
std::optional<size_t> ShapeRank() const override;
api::DataType DType() const override;
void SetShape(const std::vector<int64_t>* shape) override;
@ -184,6 +185,15 @@ std::optional<std::vector<int64_t>> ApiValueInfo::Shape() const {
return result;
}
// Reports the number of dimensions in the NodeArg's shape, or std::nullopt when GetNodeArgShape
// returns no shape for it.
std::optional<size_t> ApiValueInfo::ShapeRank() const {
  if (const auto* arg_shape = GetNodeArgShape(&node_arg_); arg_shape != nullptr) {
    return static_cast<size_t>(arg_shape->dim_size());
  }

  return std::nullopt;
}
api::DataType ApiValueInfo::DType() const {
const auto* type = node_arg_.TypeAsProto();
if (!type) {

View file

@ -4424,6 +4424,41 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) {
testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
}
// Regression test for a model with a DQ node with per-axis dequantization followed by a Transpose.
// Tests handling of a negative DQ axis: the model is run once with graph optimizations off and once with
// Level1 optimizations (which enable the transpose optimizer), and output "Z" must match exactly.
// See https://github.com/microsoft/onnxruntime/issues/12151 for more details.
TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151_NegativeDQAxis) {
  auto model_uri = ORT_TSTR("testdata/ort_github_issue_12151_neg_dq_axis.onnx");

  NameMLValMap feeds;  // no inputs for this model
  std::vector<std::string> output_names{"Z"};
  std::vector<OrtValue> fetches_orig;
  std::vector<OrtValue> fetches;

  SessionOptions so;
  so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue12151_NegativeDQAxis";

  // Baseline: results from the unmodified model.
  {
    so.graph_optimization_level = TransformerLevel::Default;  // off
    InferenceSession session{so, GetEnvironment()};
    ASSERT_STATUS_OK(session.Load(model_uri));
    ASSERT_STATUS_OK(session.Initialize());
    ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig));
  }

  // Optimized: results after the transpose optimizer has run.
  {
    so.graph_optimization_level = TransformerLevel::Level1;  // enable transpose optimizer
    InferenceSession session{so, GetEnvironment()};
    ASSERT_STATUS_OK(session.Load(model_uri));
    ASSERT_STATUS_OK(session.Initialize());
    ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));
  }

  ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<float>(),
              testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
}
// These tests use the internal testing EP with static kernels which requires a full build and contrib ops,
// and the NHWC Conv which requires contrib ops
#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS)
@ -4813,6 +4848,89 @@ TEST(TransposeOptimizerTests, ConstantFoldTransposeAndSqueezeOutputCorrectness)
testing::ContainerEq(fetches[1].Get<Tensor>().DataAsSpan<float>()));
}
// Tests the fix-up of a QDQ NodeUnit containing a per-channel DQ followed by an Unsqueeze.
// Before: DQ (axis = 0) -> Unsqueeze (axes = [0, 1, 2]) -> Op
// After:  DQ (axis = 0) -> Unsqueeze (axes = [0, 1, 2]) -> Q (axis = 3) -> DQ (axis = 3) -> Op
//
// The model is run unmodified to obtain reference outputs, then run again after the ONNX transpose
// optimizer has been applied directly to the graph; the outputs must match exactly.
TEST(TransposeOptimizerTests, FixQDQNodeUnitWithPerChannelDQUnsqueeze) {
  // Test model contains a Mul with a broadcastable/constant/per-channel DQ input. When a transpose is pushed through
  // the Mul, the constant DQ input is Unsqueezed.
  auto model_uri = ORT_TSTR("testdata/transpose_optimization_unsqueeze_dq_axis.qdq.onnx");

  // Fixed seed so both runs see the same input data.
  RandomValueGenerator random{123};
  std::vector<int64_t> input_dims{1, 3, 4, 4};
  std::vector<float> input0_data = random.Gaussian<float>(input_dims, 0.0f, 1.0f);

  OrtValue input0;
  CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], input_dims, input0_data, &input0);

  NameMLValMap feeds{{"input0", input0}};

  std::vector<std::string> output_names{"output0"};
  std::vector<OrtValue> fetches_orig;
  std::vector<OrtValue> fetches;

  SessionOptions so;
  ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1"));
  so.graph_optimization_level = TransformerLevel::Default;  // off

  // get results with no modifications to the model
  {
    InferenceSessionWrapper session{so, GetEnvironment()};
    ASSERT_STATUS_OK(session.Load(model_uri));
    ASSERT_STATUS_OK(session.Initialize());
    ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig));
  }

  {
    InferenceSessionWrapper session{so, GetEnvironment()};
    ASSERT_STATUS_OK(session.Load(model_uri));

    // We call the ONNX transpose optimizer directly to use a custom cost check function.
    Graph& graph = session.GetMutableGraph();

    namespace alias_oto = onnx_transpose_optimization;
    auto api_graph = MakeApiGraph(graph,
                                  TestCPUExecutionProvider()->CreatePreferredAllocators()[0],
                                  /*new_node_ep*/ nullptr);

    // Use a custom optimization cost check that aggressively pushes channel-last or channel-first transposes.
    auto custom_cost_fn =
        [](const alias_oto::api::GraphRef& /* graph */,
           const alias_oto::api::NodeRef& /* node */,
           const std::vector<int64_t>& perm,
           const std::unordered_set<std::string>& /* outputs_leading_to_transpose */) -> alias_oto::CostCheckResult {
      if (perm == alias_oto::ChannelFirstToLastPerm(perm.size()) ||
          perm == alias_oto::ChannelLastToFirstPerm(perm.size())) {
        return alias_oto::CostCheckResult::kPushTranspose;
      }

      return alias_oto::CostCheckResult::kFallThrough;
    };

    alias_oto::OptimizeResult result = alias_oto::Optimize(*api_graph, /*provider_type*/ "", custom_cost_fn);
    ASSERT_EQ(result.error_msg, std::nullopt);
    ASSERT_TRUE(result.graph_modified);
    ASSERT_TRUE(graph.GraphResolveNeeded());
    ASSERT_STATUS_OK(graph.Resolve());

    // Use this hack to save model for viewing if needed
    // ASSERT_STATUS_OK(Model::Save(const_cast<Model&>(session.GetModel()),
    //                              ToPathString("transpose_optimization_unsqueeze_dq_axis.qdq.updated.onnx")));

    std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
    EXPECT_EQ(op_to_count["Unsqueeze"], 1) << "1 Unsqueeze node added to broadcastable Mul weight.";
    EXPECT_EQ(op_to_count["Transpose"], 1) << "2 Transposes at the I/O cancel. 1 Transpose inserted above Mul weight.";

    ASSERT_STATUS_OK(session.Initialize());
    ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches));
  }

  ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<float>(),
              testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
}
static void CheckSharedInitializerHandling(bool broadcast) {
auto model_uri = broadcast ? ORT_TSTR("testdata/transpose_optimizer_shared_initializers_broadcast.onnx")
: ORT_TSTR("testdata/transpose_optimizer_shared_initializers.onnx");

View file

@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import numpy as np
import onnx
if __name__ == "__main__":
    """
    Creates a QDQ model with a per-channel DQ weight that is Unsqueezed and Transposed by the Transpose optimizer.
    """
    input0_shape = (1, 3, 4, 4)

    input0 = onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, input0_shape)
    # Output shape is left unspecified here; shape inference is run on the model before it is saved.
    output0 = onnx.helper.make_tensor_value_info("output0", onnx.TensorProto.FLOAT, None)

    # Per-tensor (scalar) quantization parameters shared by the activation Q/DQ nodes.
    scale_1 = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "scale_1")
    zp_128 = onnx.numpy_helper.from_array(np.array(128, dtype=np.uint8), "zp_128")
    scale_inv_255 = onnx.numpy_helper.from_array(np.array(1.0 / 255.0, dtype=np.float32), "scale_inv_255")
    zp_0 = onnx.numpy_helper.from_array(np.array(0, dtype=np.uint8), "zp_0")

    # Per-channel quantized Mul weight: shape (3,), with one scale and zero-point per element.
    mul_weight_i8_data = np.array([1, 2, 3], dtype=np.int8)
    mul_weight_scales_data = np.array([1.0, 1.0, 1.0], dtype=np.float32)
    mul_weight_zps_data = np.array([0, 0, 0], dtype=np.int8)

    mul_weight = onnx.numpy_helper.from_array(mul_weight_i8_data, "mul_weight")
    mul_weight_scales = onnx.numpy_helper.from_array(mul_weight_scales_data, "mul_weight_scales")
    mul_weight_zps = onnx.numpy_helper.from_array(mul_weight_zps_data, "mul_weight_zps")

    # Transpose to channel-last
    tp0_node = onnx.helper.make_node("Transpose", ["input0"], ["tp0_out"], name="tp0_node", perm=(0, 2, 3, 1))

    # Q_0
    q0_node = onnx.helper.make_node("QuantizeLinear", ["tp0_out", "scale_1", "zp_128"], ["q0_out"], name="q0_node")

    # DQ_0
    dq0_node = onnx.helper.make_node("DequantizeLinear", ["q0_out", "scale_1", "zp_128"], ["dq0_out"], name="dq0_node")

    # Sigmoid
    sigmoid_node = onnx.helper.make_node("Sigmoid", ["dq0_out"], ["sigmoid_out"], name="sigmoid_node")

    # Q_1
    q1_node = onnx.helper.make_node(
        "QuantizeLinear", ["sigmoid_out", "scale_inv_255", "zp_0"], ["q1_out"], name="q1_node"
    )

    # DQ_1
    dq1_node = onnx.helper.make_node(
        "DequantizeLinear", ["q1_out", "scale_inv_255", "zp_0"], ["dq1_out"], name="dq1_node"
    )

    # DQ_weight (per-channel along axis 0 of the 1D weight)
    dq_weight_node = onnx.helper.make_node(
        "DequantizeLinear",
        ["mul_weight", "mul_weight_scales", "mul_weight_zps"],
        ["dq_weight_out"],
        name="dq_weight_node",
        axis=0,
    )

    # Mul
    mul_node = onnx.helper.make_node("Mul", ["dq1_out", "dq_weight_out"], ["mul_out"], name="mul_node")

    # Q_2
    q2_node = onnx.helper.make_node("QuantizeLinear", ["mul_out", "scale_inv_255", "zp_0"], ["q2_out"], name="q2_node")

    # DQ_2
    dq2_node = onnx.helper.make_node(
        "DequantizeLinear", ["q2_out", "scale_inv_255", "zp_0"], ["dq2_out"], name="dq2_node"
    )

    # Transpose to channel-first
    tp1_node = onnx.helper.make_node("Transpose", ["dq2_out"], ["output0"], name="tp1_node", perm=(0, 3, 1, 2))

    graph = onnx.helper.make_graph(
        [
            tp0_node,
            q0_node,
            dq0_node,
            sigmoid_node,
            q1_node,
            dq1_node,
            dq_weight_node,
            mul_node,
            q2_node,
            dq2_node,
            tp1_node,
        ],
        "transpose_opt_unsqueeze_dq_axis",
        [input0],
        [output0],
        initializer=[scale_1, zp_128, scale_inv_255, zp_0, mul_weight, mul_weight_scales, mul_weight_zps],
    )
    opset_imports = [
        onnx.helper.make_opsetid("", 19),
    ]
    qdq_model = onnx.helper.make_model(graph, opset_imports=opset_imports)

    print("[INFO]: Running onnx.checker on qdq model")
    qdq_model = onnx.shape_inference.infer_shapes(qdq_model)
    onnx.checker.check_model(qdq_model, True)
    qdq_model_path = "transpose_optimization_unsqueeze_dq_axis.qdq.onnx"

    print(f"[INFO]: Saving {qdq_model_path}")
    onnx.save_model(qdq_model, qdq_model_path)

Binary file not shown.