From 724009289b74315ea70d856c19e9353a080b9086 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 12 Nov 2021 07:18:16 +1000 Subject: [PATCH] Fix Issue #9671 (#9691) * Fix #9671 by running the level 1 rewrite rules first and allowing the transpose optimizer to run multiple times to ensure it completes in level 1. Removed unnecessary call to GenerateRuleBasedGraphTransformer as there are no level 2 rewrite rules. --- .../core/optimizer/graph_transformer_utils.cc | 24 +++++++++--------- .../ort_transpose_optimizer.h | 4 --- .../optimizer/transpose_optimizer_test.cc | 20 ++++++++++++--- onnxruntime/test/testdata/gh_issue_9671.onnx | Bin 0 -> 6585 bytes 4 files changed, 29 insertions(+), 19 deletions(-) create mode 100644 onnxruntime/test/testdata/gh_issue_9671.onnx diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 3fd9434726..84a9819d55 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -146,14 +146,23 @@ std::vector> GenerateTransformers( const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ const std::unordered_set& rules_and_transformers_to_disable) { std::vector> transformers; - std::unique_ptr rule_transformer = nullptr; - bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; + bool disable_quant_qdq = + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; #ifndef DISABLE_CONTRIB_OPS - bool enable_gelu_approximation = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableGeluApproximation, "0") == "1"; + bool enable_gelu_approximation = + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableGeluApproximation, "0") == "1"; #endif switch (level) { case TransformerLevel::Level1: { + // RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run) + // so run them first so there is potentially less for the more intensive optimizations like ConstantFolding, + // CommonSubexpressionElimination and TransposeOptimizer to do. + auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}); + if (rule_transformer != nullptr) { + transformers.emplace_back(std::move(rule_transformer)); + } + // no filtering on execution provider for L1 optimizations as they only use official ONNX operators transformers.emplace_back(std::make_unique()); transformers.emplace_back(std::make_unique(cpu_execution_provider, !disable_quant_qdq)); @@ -163,16 +172,11 @@ std::vector> GenerateTransformers( session_options.free_dimension_overrides)); auto cpu_allocator = cpu_execution_provider.GetAllocator(0, OrtMemTypeDefault); transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); - - rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}); } break; case TransformerLevel::Level2: { std::unordered_set cpu_ep = {onnxruntime::kCpuExecutionProvider}; - // create rule based transformer consisting of all the level2 rewrite rules - rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, cpu_ep); - #ifndef DISABLE_CONTRIB_OPS const std::unordered_set cuda_rocm_eps = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider}; @@ -238,10 +242,6 @@ std::vector> GenerateTransformers( ORT_THROW("Unsupported optimization level: ", static_cast(level)); } - if (rule_transformer != nullptr) { - transformers.emplace_back(std::move(rule_transformer)); - } - FilterTransformers(transformers, rules_and_transformers_to_disable); return transformers; diff --git a/onnxruntime/core/optimizer/transpose_optimizer/ort_transpose_optimizer.h b/onnxruntime/core/optimizer/transpose_optimizer/ort_transpose_optimizer.h index 46816caeee..9535e23a18 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer/ort_transpose_optimizer.h +++ b/onnxruntime/core/optimizer/transpose_optimizer/ort_transpose_optimizer.h @@ -21,10 +21,6 @@ class TransposeOptimizer : public GraphTransformer { : GraphTransformer("TransposeOptimizer"), cpu_allocator_(std::move(cpu_allocator)) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; - - // One run should be sufficient. Multiple runs should be ok but are prohibited to prevent any possibility of an - // infinite loop. - bool ShouldOnlyApplyOnce() const override { return true; } }; } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index bc55562214..c4afac1b02 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -8,11 +8,12 @@ #include "graph_transform_test_builder.h" #include "core/graph/graph.h" +#include "test/test_environment.h" +#include "test/util/include/asserts.h" namespace onnxruntime { namespace test { - void SetNodeArgShape(NodeArg* node_arg, const std::optional>& shape) { if (shape == std::nullopt) { node_arg->ClearShape(); @@ -77,7 +78,6 @@ int EstimateTransposeCost(const Graph& graph) { return cost; } - TEST(TransposeOptimizerTests, TestSplit) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = builder.MakeInput({4, 6, 10}, 0.0, 1.0); @@ -3849,8 +3849,22 @@ TEST(TransposeOptimizerTests, TestOmitIdentityTranspose) { /*opset_version*/ 15); } +// regression test for a model where the transpose optimizations were not completed in a single pass in level 1. +// fixed by +// a) moving the RewriteRule level 1 optimizations so they run prior to the transpose optimizer; and +// b) not returning `true` from TransposeOptimizer::ShouldOnlyApplyOnce as it should be safe to run the +// transpose optimizer multiple times to ensure it completes in level 1. +// either of those changes would have fixed the issue. +// see https://github.com/microsoft/onnxruntime/issues/9671 for more details. +TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue9671) { + auto model_uri = ORT_TSTR("testdata/gh_issue_9671.onnx"); - + SessionOptions so; + so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue9671"; + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_uri)); + ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/gh_issue_9671.onnx b/onnxruntime/test/testdata/gh_issue_9671.onnx new file mode 100644 index 0000000000000000000000000000000000000000..53f1215402add38ae3bfe03279c225753f817426 GIT binary patch literal 6585 zcmb_h-EQPG6i(XdOzJMQL|RnB-;%2(NK5SeOa?9}yArD%X(hxi61x{fnU32T=uD?G znWf7WZvYoOz+P~{Yw%=XJC0)~ak>Ipsd3`t^PO|-bNt!02KN5qVzs`U&QEh^;NSD_ zP4NrYxN9Kev)=;%!OnNkejCr1o9zN7aUA(Wn4kzH>3r(!KbqoYKF?qJu2nX{6llc3 zyX|Vy>>6)Wj2MN(aEMZT?SOBWo2xCxKV#_bi#GUPMX3ff6-CHT^R)w>uCwK4b-BUF z-4|`|Ht4PJdU0Z&7$-X?yC=Pqy;tS{Y@7NEJ*pQ3H;y1m*SOsa-symE@N|2MpKKR_ zyD!?_A+UbLcr~3bHpk{GbBDznsV)qW&$P(#4m;v4#8pcxHEov0D5CBu@s7IV&0SNC zjbk)SP^iW`Zi_dKHiLuteC?0G$Ld<`R9)HFw19n z4Sly?w7{kt6TJ?m&T5J|sX(CsJp-L)+DcK7!cP@Cw~AExn$#zVT%JVqt(iG9t#gHG zc9$6w1u5u8`qfl3t7bHKM>8!c1jm{#sI_j5H#6v14=$IdSq?+DU$ndf(95ppn`47k zjo-}9AW!~(laMeeYBy=F%Qvm1%u{U^Wkn!J?JV*#b_cu(nov9?JE2fF6mkUAu!8zQ zyU3V*$iAbQ;G&zBtpl<3z(lTFhkPHRb&+fZ18~`m1;LG2wxh9TdYL=c<1Bx?y?`UP zU$jUZyHey_YYxw^4>Vwl>^dbP&feqf8u{)ndE8WqAox(D>;j+V=?3R^rvQcH zv0AhS2=&?mb~ayT7b2NZSS1tIt~WMD-S|htXvB#{l|hJ56p>DX$r+5rnIw-GlS72U z2C6d3Qo@8$SWxm-6ib{6H>+16->CdtPdlfIqhU-r)q&uxh;SA00$wy=_=pBfk_0-E z86)2590+2GNNEu-&Al>h;Bpmem`1}GCL~Ns)k358m{TT4869JD4^@C_^3WR?3!-8#M{iIDuO_*PJ5evNNQBY444IG)rk zVJdt}5}}wC-wH@lbCOr%TeTtQD&MO0$}A`dyID|;Z`CwS&eeiY?^ckaZlu~qD5)x_ zfLcD~?(L$;ONC=f0wjopTNQx|ViCd{0MUN}B+Q8e4h7&#wf0IDran|_fz(H>SNB1w zlhLIV3hEdn?PdR8L!(Y;BtsL5(Z*UYbVWiN-_j{<^4C~m=vTW<1FftiooLjLZnTOh zhOk2W(^ScZG&6;b-Q7r=rssIEaQY9)-w?Sy`XuFjFS}UHvfGC7%NMUq(|Z&2COD_> z200#>-W#Bw)4Pu^*=a$%-Y>2kYP3(Xl+`t@p+yC<6>cL9mT6xYw9-CF`i4wyr+w0R zCO|AIlPe@AC0-C~v`-qutFS`qq;3uEYe>kjrm72~mLE1`A7?Rr8mZ4W#{d5^>%@g{ydw@y`L+ON7BZZRFEu@g}IC25UZucru z7%IoA@_Nr8)PxX1LJisBDpPAurAYO-`rc0PlE4qs&6TtNq`0CCNA7?(TW*TDv`VrF zL=uG|y-pQeG?MRAVZ&4|Z=27gA+VyacE311d7oRmM}zyy?pNL3GLOc`yR=G>pI-4W z_W0BL#4_uHYqy