From 2e1fa3ccb75a6441ee0a8a6da92c396daac1816f Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 29 Oct 2020 00:07:13 -0700 Subject: [PATCH] Fix GeluRecompute for 2 inputs case. (#5573) * Add test for FastGelu + GeluRecompute. * Fix GeluRecompute for 2 inputs case. * Fix test for BiasGelu + GeluRecompute. * Copy all inputs to Gelu, not just 2. * Move GeluRecompute test to training-specific file. --- include/onnxruntime/core/graph/graph.h | 2 +- .../fusion/bias_gelu_fusion_recompute.onnx | Bin 0 -> 478 bytes .../transform/fusion/bias_gelu_matmul_gen.py | 36 ++++++++++++++++++ .../core/optimizer/localized_recompute.cc | 2 +- .../test/optimizer/graph_transform_test.cc | 30 +++++++++++++++ 5 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion_recompute.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/bias_gelu_matmul_gen.py diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 0c153350cf..9bd072d5fa 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -108,7 +108,7 @@ class Node { /** Gets the domain of the OperatorSet that specifies the operator returned by #OpType. */ const std::string& Domain() const noexcept { return domain_; } - /** Gets the Node's exection priority. + /** Gets the Node's execution priority. @remarks Lower value means higher priority */ int Priority() const noexcept { return priority_; }; diff --git a/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion_recompute.onnx b/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion_recompute.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0e083e22567d430e7068cd6233ae1f45f6e7b80c GIT binary patch literal 478 zcmZ8dyH3ME5cF|QY%dhnQ9x`E$`uHxFc1nF6vsf7hAK*PGO;6B_<<5%>iiU+!k6#^ z?AbYk=ti?Mv$Jbk4__2~AY&jQv*RSW7L}|$Qz1^rNdgzp(cmQ6Ml7mwT{BBTC8Y>B zjj|0~{#6z-*GdJOPy`q{l{uV%_A@J8E!!XjoIb6V7^F#Xu*fb))d;$FLECN!)-JwT z-pm=9cqXzaQM(2}W p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Div"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Erf"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 2); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "com.microsoft.BiasGelu") { + ASSERT_TRUE(node.InputDefs().size() == 2); + } + } +} + // We only tested on CUDA run. #if defined(USE_CUDA) TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {