Fix GH issue 12151 by using inverse perms for updating DQ axis attribute (#12158)

* Fix GH issue 12151.

When the Transpose is moved in front of a DQ node, the DQ node's axis attribute must be updated using the inverse of the perms used to transpose the input. This only applies if the DQ node is doing per-axis dequantization.
This commit is contained in:
Scott McKay 2022-07-13 18:02:58 +10:00 committed by GitHub
parent 785f74979b
commit 75cf5dc2c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 2 deletions

View file

@ -1997,7 +1997,9 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) {
continue;
}
if (!HandleQuantizeDequantizeScale(ctx.graph, *perm, *dq_node, ctx.opset)) {
// we're moving the Transpose to before the DQ, so we need to use the inverse permutations to update the axis
// attribute correctly when doing per-axis dequantization
if (!HandleQuantizeDequantizeScale(ctx.graph, InvertPerm(*perm), *dq_node, ctx.opset)) {
continue;
}

View file

@ -5,6 +5,8 @@
#include <vector>
#include "gtest/gtest.h"
#include "gmock/gmock.h"
#include "graph_transform_test_builder.h"
#include "core/graph/graph.h"
@ -3620,7 +3622,6 @@ TEST(TransposeOptimizerTests, TestDequantizeLinearTransposePropagation) {
EXPECT_EQ(op_types_in_order, expected_op_types_in_order);
};
TransformerTester(build_test_case_1,
check_graph,
TransformerLevel::Default,
@ -4047,5 +4048,41 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue10305) {
ASSERT_STATUS_OK(session_object.Load(model_uri));
ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization
}
// Regression test for a model with a DQ node doing per-axis dequantization followed by a Transpose.
// The second phase can swap those around, but needs to use the correct perms for updating the 'axis'
// attribute in the DQ node.
// See https://github.com/microsoft/onnxruntime/issues/12151 for more details.
//
// The model is run twice, once at TransformerLevel::Default and once at TransformerLevel::Level1,
// and the "Z" output of the two runs must be element-wise identical.
TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) {
  auto model_uri = ORT_TSTR("testdata/ort_github_issue_12151.onnx");

  NameMLValMap feeds;  // no inputs for this model
  std::vector<std::string> output_names{"Z"};
  std::vector<OrtValue> fetches_orig;
  std::vector<OrtValue> fetches;

  SessionOptions so;
  so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue12151";

  // Baseline run: produces the reference output.
  {
    so.graph_optimization_level = TransformerLevel::Default;  // off
    InferenceSession session_object{so, GetEnvironment()};
    ASSERT_STATUS_OK(session_object.Load(model_uri));
    ASSERT_STATUS_OK(session_object.Initialize());
    ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches_orig));
  }

  // Optimized run: same model, same (empty) feeds.
  {
    so.graph_optimization_level = TransformerLevel::Level1;  // enable transpose optimizer
    InferenceSession session_object{so, GetEnvironment()};
    ASSERT_STATUS_OK(session_object.Load(model_uri));
    ASSERT_STATUS_OK(session_object.Initialize());
    ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches));
  }

  // Fail cleanly, rather than index out of bounds, if either run did not return one value per requested output.
  ASSERT_EQ(fetches_orig.size(), output_names.size());
  ASSERT_EQ(fetches.size(), output_names.size());

  ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<float>(),
              testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
}
} // namespace test
} // namespace onnxruntime

Binary file not shown.