mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Fix GH issue 12151 by using inverse perms for updating DQ axis attribute (#12158)
* Fix GH issue 12151. Need to use inverse perms for updating that axis to what is used for transposing the input. This only applies if the DQ node is doing per-axis dequantization.
This commit is contained in:
parent
785f74979b
commit
75cf5dc2c9
3 changed files with 41 additions and 2 deletions
|
|
@ -1997,7 +1997,9 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) {
|
|||
continue;
|
||||
}
|
||||
|
||||
if (!HandleQuantizeDequantizeScale(ctx.graph, *perm, *dq_node, ctx.opset)) {
|
||||
// we're moving the Transpose to before the DQ, so we need to use the inverse permutations to update the axis
|
||||
// attribute correctly when doing per-axis dequantization
|
||||
if (!HandleQuantizeDequantizeScale(ctx.graph, InvertPerm(*perm), *dq_node, ctx.opset)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
#include "graph_transform_test_builder.h"
|
||||
|
||||
#include "core/graph/graph.h"
|
||||
|
|
@ -3620,7 +3622,6 @@ TEST(TransposeOptimizerTests, TestDequantizeLinearTransposePropagation) {
|
|||
EXPECT_EQ(op_types_in_order, expected_op_types_in_order);
|
||||
};
|
||||
|
||||
|
||||
TransformerTester(build_test_case_1,
|
||||
check_graph,
|
||||
TransformerLevel::Default,
|
||||
|
|
@ -4047,5 +4048,41 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue10305) {
|
|||
ASSERT_STATUS_OK(session_object.Load(model_uri));
|
||||
ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization
|
||||
}
|
||||
|
||||
// regression test for a model with DQ node with per-axis dequantization followed by a Transpose.
|
||||
// the second phase can swap those around, but needs to use the correct perms for updating the 'axis'
|
||||
// attribute in the DQ node.
|
||||
// see https://github.com/microsoft/onnxruntime/issues/12151 for more details.
|
||||
TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) {
|
||||
Status status;
|
||||
auto model_uri = ORT_TSTR("testdata/ort_github_issue_12151.onnx");
|
||||
|
||||
NameMLValMap feeds; // no inputs for this model
|
||||
std::vector<std::string> output_names{"Z"};
|
||||
std::vector<OrtValue> fetches_orig;
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
SessionOptions so;
|
||||
so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue12151";
|
||||
|
||||
{
|
||||
so.graph_optimization_level = TransformerLevel::Default; // off
|
||||
InferenceSession session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.Load(model_uri));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches_orig));
|
||||
}
|
||||
|
||||
{
|
||||
so.graph_optimization_level = TransformerLevel::Level1; // enable transpose optimizer
|
||||
InferenceSession session_object{so, GetEnvironment()};
|
||||
ASSERT_STATUS_OK(session_object.Load(model_uri));
|
||||
ASSERT_STATUS_OK(session_object.Initialize());
|
||||
ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches));
|
||||
}
|
||||
|
||||
ASSERT_THAT(fetches_orig[0].Get<Tensor>().DataAsSpan<float>(),
|
||||
testing::ContainerEq(fetches[0].Get<Tensor>().DataAsSpan<float>()));
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/ort_github_issue_12151.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/ort_github_issue_12151.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue