From 9eba9fba7ce4b0ce7e982313116a766a3ae17c2d Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Tue, 9 Jun 2020 14:33:34 -0700 Subject: [PATCH] Fix for BiasGelu fusion optimizer (#4160) * Fix for BiasGelu fusion optimizer * changes per review comments --- onnxruntime/core/graph/graph_utils.cc | 13 +++- onnxruntime/core/graph/graph_utils.h | 3 + .../test/optimizer/graph_transform_test.cc | 57 +++++++++++++++++- .../fusion/bias_gelu_fusion_format_2.onnx | Bin 0 -> 505 bytes 4 files changed, 69 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion_format_2.onnx diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index c63d003e24..fd6d1fbfcb 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -234,7 +234,8 @@ static void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_nod auto input_edges = GetNodeInputEdges(src_node); for (auto cur = input_edges.cbegin(), end = input_edges.cend(); cur != end; ++cur) { - graph.AddEdge(cur->src_node, target_idx, cur->src_arg_index, cur->dst_arg_index); + auto target_arg_index = GetNodeInputIndexFromInputName(target_node, cur->arg_name); + graph.AddEdge(cur->src_node, target_idx, cur->src_arg_index, target_arg_index); } RemoveGraphEdges(graph, input_edges); @@ -261,6 +262,16 @@ static void MoveAllNodeOutputs(Graph& graph, Node& src_node, Node& target_node) //--- end of local helpers --- //---------------------------- +int GetNodeInputIndexFromInputName(const Node& node, const std::string& input_name) { + auto itr = std::find_if(node.InputDefs().begin(), node.InputDefs().end(), + [&input_name](const NodeArg* input) { return input->Name() == input_name; }); + ORT_ENFORCE(itr != node.InputDefs().end(), + "Attempting to get index for an input which does not exist."); + auto index = std::distance(node.InputDefs().begin(), itr); + return static_cast(index); + +} + const std::string& GetNodeInputName(const Node& node, int index) { const auto& inputs = node.InputDefs(); ORT_ENFORCE(index >= 0 && static_cast(index) < inputs.size(), diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index ec4a5388b2..b5d927eff0 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -69,6 +69,9 @@ bool AllNodeInputsAreConstant(const Graph& graph, const Node& node, InitializedT /** Gets the name of the incoming NodeArg with the specified index for the given node. */ const std::string& GetNodeInputName(const Node& node, int index); +/** Gets the index of an input arg with the specified input arg name. */ +int GetNodeInputIndexFromInputName(const Node& node, const std::string& input_name); + /** Gets the name of the outgoing NodeArg with the specified index for the given node. */ const std::string& GetNodeOutputName(const Node& node, int index); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 2a9b1bb686..722bdb1ece 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -48,6 +48,8 @@ #include "test/capturing_sink.h" #include "test/framework/test_utils.h" #include "test/optimizer/graph_transform_test_fixture.h" +#include "test/compare_ortvalue.h" +#include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/test_environment.h" #include "asserts.h" @@ -1199,7 +1201,7 @@ TEST_F(GraphTransformationTests, ReshapeFusionGraphInputsTest) { ASSERT_EQ(op_to_count["Concat"], 1); ASSERT_EQ(op_to_count["Reshape"], 1); } - + TEST_F(GraphTransformationTests, ReshapeFusionMultipleValuesInInitializerSubgraphTest) { auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_multiple_values_in_initializer_tensor_1.onnx"; std::shared_ptr p_model; @@ -1399,8 +1401,6 @@ TEST_F(GraphTransformationTests, ExpandElimination) { ASSERT_TRUE(op_to_count["Expand"] == 3); } - - TEST_F(GraphTransformationTests, CastElimination) { auto model_uri = MODEL_FOLDER "cast_elimination.onnx"; std::shared_ptr model; @@ -1721,6 +1721,57 @@ TEST_F(GraphTransformationTests, BiasGeluTest) { ASSERT_TRUE(op_to_count["BiasGelu"] == 1); } + +// BiasGelu allows input switching based on input dimensions. +// This test validates the input edges are plugged correct in the optimized graph. +TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) { + auto model_uri = MODEL_FOLDER "fusion/bias_gelu_fusion_format_2.onnx"; + + // create inputs and outputs + RandomValueGenerator random{}; + NameMLValMap feeds; + + OrtValue mlvalue_b_i; + std::vector dims_b_i = {3072}; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_b_i, + random.Uniform(dims_b_i, 0.0f, 1.0f), &mlvalue_b_i); + feeds.insert(std::make_pair("B_I", mlvalue_b_i)); + + OrtValue mlvalue_a_i; + std::vector dims_a_i = {3, 512, 3072}; + CreateMLValue(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_a_i, + random.Uniform(dims_a_i, 0.0f, 1.0f), &mlvalue_a_i); + feeds.insert(std::make_pair("A_I", mlvalue_a_i)); + + std::vector output_names; + output_names.push_back("C"); + + auto run_model_test = [&](TransformerLevel level, std::vector& fetches) { + SessionOptions session_options; + session_options.graph_optimization_level = level; + session_options.session_logid = "OptimizerTests"; + InferenceSession session{session_options, GetEnvironment()}; + ASSERT_TRUE(session.Load(model_uri).IsOK()); + ASSERT_TRUE(session.Initialize().IsOK()); + + RunOptions run_options; + ASSERT_STATUS_OK(session.Run(run_options, feeds, output_names, &fetches)); + }; + + // run model with and w/o optimizations and compare the results + std::vector unoptimized_fetches; + run_model_test(TransformerLevel::Default, unoptimized_fetches); + + std::vector optimized_fetches; + run_model_test(TransformerLevel::MaxLevel, optimized_fetches); + + // Compare results + double per_sample_tolerance = 1e-3; + double relative_per_sample_tolerance = 0.0; + auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], per_sample_tolerance, relative_per_sample_tolerance, false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; +} + // Test Gelu -> FastGelu TEST_F(GraphTransformationTests, GeluApproximation_Gelu) { auto model_uri = MODEL_FOLDER "approximation/gelu.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion_format_2.onnx b/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion_format_2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d64e2434d36d6184afcf3b8873fd771b591bbb43 GIT binary patch literal 505 zcmZ8e%TB{E5R7pmCkqnAJ%Cbrz~TTs5#ZEIWrv42@B_VAKw2pxBvM=V+@IpVkoW=C zPTW@9Ml-Xsvv$nJKNN;QrG5b>b>7s=6(=%{zD%X1ps%K~va+)BvT9}3$^#$@h~)8G zQQYXptkTdTL1j_E6>MopigK+#+y-W$wNeCvK9p;?-gnyVozp@WiU31xp>y!_oXr;V zHVA>BkBd15c@h#V(&gA1*|uHQwz~k<#k=KB=Oj-L5{iSqBYPjzoJ>RZ=u_&rg(t6MS%0wnEPcc3RRZ`XEu>3)O{VGt933XDgl{?wTV hw?N~~*_6AF;fIi*D{|bs0FwWoyYqAR&Z7=`M}Htfev1GA literal 0 HcmV?d00001