From 75cf5dc2c9260e082172ea4e8c5384aef6e4cfe2 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 13 Jul 2022 18:02:58 +1000 Subject: [PATCH] Fix GH issue 12151 by using inverse perms for updating DQ axis attribute (#12158) * Fix GH issue 12151. Need to use inverse perms for updating that axis to what is used for transposing the input. This only applies if the DQ node is doing per-axis dequantization. --- .../transpose_optimizer.cc | 4 +- .../optimizer/transpose_optimizer_test.cc | 39 +++++++++++++++++- .../test/testdata/ort_github_issue_12151.onnx | Bin 0 -> 380 bytes 3 files changed, 41 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/test/testdata/ort_github_issue_12151.onnx diff --git a/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc index 9e552a41c3..81adcee7cb 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc @@ -1997,7 +1997,9 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) { continue; } - if (!HandleQuantizeDequantizeScale(ctx.graph, *perm, *dq_node, ctx.opset)) { + // we're moving the Transpose to before the DQ, so we need to use the inverse permutations to update the axis + // attribute correctly when doing per-axis dequantization + if (!HandleQuantizeDequantizeScale(ctx.graph, InvertPerm(*perm), *dq_node, ctx.opset)) { continue; } diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 68e2be1b34..cfeb9a2202 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -5,6 +5,8 @@ #include #include "gtest/gtest.h" +#include "gmock/gmock.h" + #include "graph_transform_test_builder.h" #include "core/graph/graph.h" @@ -3620,7 +3622,6 @@ TEST(TransposeOptimizerTests, TestDequantizeLinearTransposePropagation) { EXPECT_EQ(op_types_in_order, expected_op_types_in_order); }; - TransformerTester(build_test_case_1, check_graph, TransformerLevel::Default, @@ -4047,5 +4048,41 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue10305) { ASSERT_STATUS_OK(session_object.Load(model_uri)); ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization } + +// regression test for a model with DQ node with per-axis dequantization followed by a Transpose. +// the second phase can swap those around, but needs to use the correct perms for updating the 'axis' +// attribute in the DQ node. +// see https://github.com/microsoft/onnxruntime/issues/12151 for more details. +TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) { + Status status; + auto model_uri = ORT_TSTR("testdata/ort_github_issue_12151.onnx"); + + NameMLValMap feeds; // no inputs for this model + std::vector output_names{"Z"}; + std::vector fetches_orig; + std::vector fetches; + + SessionOptions so; + so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue12151"; + + { + so.graph_optimization_level = TransformerLevel::Default; // off + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); + ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches_orig)); + } + + { + so.graph_optimization_level = TransformerLevel::Level1; // enable transpose optimizer + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); + ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), + testing::ContainerEq(fetches[0].Get().DataAsSpan())); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/ort_github_issue_12151.onnx b/onnxruntime/test/testdata/ort_github_issue_12151.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f796b46f1bdc26302f4a23513f16cc8ebf77d15b GIT binary patch literal 380 zcmYk2UrWMJ0L6Fj=DfR8n-?v$5>X_W*l0wANY1tJ-%O$9zlU_$9tO49wm^L9t*_As zs?X8aXcYZ<_`!k0;T#T%66`w8x*Ax1;MYDN)-aghh`%)NEsQso=gwEI?F0l&bNJ4B zd@`ND-dIIJ`_;PbSf(fPm@(J3p8>A`VhMvO0ka$zc&6j8T;D2k#*Y+}m|0oggF-1# zp_MQYCF!N>_(Aor;cJoj0uV$rVo*eYDwuTF@6(E1uEE9Wje&0jS@{bYm2yG$`u1@&2ZtiRsXS>c6wP cG|kAjyBN&KHclZx1Pl=jVL~DyvIVE_OC literal 0 HcmV?d00001