diff --git a/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc index 9e552a41c3..81adcee7cb 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer/transpose_optimizer.cc @@ -1997,7 +1997,9 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) { continue; } - if (!HandleQuantizeDequantizeScale(ctx.graph, *perm, *dq_node, ctx.opset)) { + // we're moving the Transpose to before the DQ, so we need to use the inverse permutations to update the axis + // attribute correctly when doing per-axis dequantization + if (!HandleQuantizeDequantizeScale(ctx.graph, InvertPerm(*perm), *dq_node, ctx.opset)) { continue; } diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 68e2be1b34..cfeb9a2202 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -5,6 +5,8 @@ #include #include "gtest/gtest.h" +#include "gmock/gmock.h" + #include "graph_transform_test_builder.h" #include "core/graph/graph.h" @@ -3620,7 +3622,6 @@ TEST(TransposeOptimizerTests, TestDequantizeLinearTransposePropagation) { EXPECT_EQ(op_types_in_order, expected_op_types_in_order); }; - TransformerTester(build_test_case_1, check_graph, TransformerLevel::Default, @@ -4047,5 +4048,41 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue10305) { ASSERT_STATUS_OK(session_object.Load(model_uri)); ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization } + +// regression test for a model with DQ node with per-axis dequantization followed by a Transpose. +// the second phase can swap those around, but needs to use the correct perms for updating the 'axis' +// attribute in the DQ node. +// see https://github.com/microsoft/onnxruntime/issues/12151 for more details. 
+TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) { + Status status; + auto model_uri = ORT_TSTR("testdata/ort_github_issue_12151.onnx"); + + NameMLValMap feeds; // no inputs for this model + std::vector output_names{"Z"}; + std::vector fetches_orig; + std::vector fetches; + + SessionOptions so; + so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue12151"; + + { + so.graph_optimization_level = TransformerLevel::Default; // baseline run: graph optimizations off + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); + ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches_orig)); + } + + { + so.graph_optimization_level = TransformerLevel::Level1; // enable transpose optimizer + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); + ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), + testing::ContainerEq(fetches[0].Get().DataAsSpan())); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/ort_github_issue_12151.onnx b/onnxruntime/test/testdata/ort_github_issue_12151.onnx new file mode 100644 index 0000000000..f796b46f1b Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_12151.onnx differ