diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 0c153350cf..9bd072d5fa 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -108,7 +108,7 @@ class Node { /** Gets the domain of the OperatorSet that specifies the operator returned by #OpType. */ const std::string& Domain() const noexcept { return domain_; } - /** Gets the Node's exection priority. + /** Gets the Node's execution priority. @remarks Lower value means higher priority */ int Priority() const noexcept { return priority_; }; diff --git a/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion_recompute.onnx b/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion_recompute.onnx new file mode 100644 index 0000000000..0e083e2256 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/bias_gelu_fusion_recompute.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/bias_gelu_matmul_gen.py b/onnxruntime/test/testdata/transform/fusion/bias_gelu_matmul_gen.py new file mode 100644 index 0000000000..9d78e622f9 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/bias_gelu_matmul_gen.py @@ -0,0 +1,36 @@ +import onnx +from onnx import helper +from onnx import TensorProto + +graph = helper.make_graph( + [ # nodes + # Add node before Gelu + helper.make_node("Add", ["A", "B"], ["add0_out"], "add0"), + + # Gelu subgraph + helper.make_node("Div", ["add0_out", "div_const"], ["div_out"], "div"), + helper.make_node("Mul", ["add0_out", "mul_const"], ["mul_out"], "mul0"), + helper.make_node("Erf", ["div_out"], ["erf_out"], "erf"), + helper.make_node("Add", ["erf_out", "add_const"], ["add1_out"], "add1"), + helper.make_node("Mul", ["mul_out", "add1_out"], ["C"], "mul1"), + + # MatMul node after Gelu for recompute + helper.make_node("MatMul", ["X", "C"], ["D"], "matmul"), + ], + "Gelu_Add_Fusion_Recompute", #name + [ # inputs + helper.make_tensor_value_info('A', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]), + helper.make_tensor_value_info('B', TensorProto.FLOAT, [3072]), + helper.make_tensor_value_info('X', TensorProto.FLOAT, ['unk_5', 'unk_6', 3072]), + ], + [ # outputs + helper.make_tensor_value_info('D', TensorProto.FLOAT, ['unk_3', 'unk_4', 'unk_5']), + ], + [ # initializers + helper.make_tensor('div_const', TensorProto.FLOAT, [], [1.4142135381698608]), + helper.make_tensor('mul_const', TensorProto.FLOAT, [], [0.5]), + helper.make_tensor('add_const', TensorProto.FLOAT, [], [1]), + ]) + +model = helper.make_model(graph) +onnx.save(model, r'bias_gelu_fusion_recompute.onnx') diff --git a/orttraining/orttraining/core/optimizer/localized_recompute.cc b/orttraining/orttraining/core/optimizer/localized_recompute.cc index 0c5eb31c40..6d6f774bbe 100644 --- a/orttraining/orttraining/core/optimizer/localized_recompute.cc +++ b/orttraining/orttraining/core/optimizer/localized_recompute.cc @@ -44,7 +44,7 @@ Status GeluRecompute::ApplyImpl(Graph& graph, bool& modified, int /*graph_level* Node& recompute_node = graph.AddNode(node.Name() + "_recompute", node.OpType(), "Recompute of " + node.Name(), - {node.MutableInputDefs()[0]}, + node.MutableInputDefs(), {&recomputed_output}, &node.GetAttributes(), node.Domain()); diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 1ad6963daf..6be077c7f3 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -10,12 +10,15 @@ #include "gtest/gtest.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/utils.h" +#include "core/optimizer/bias_gelu_fusion.h" +#include "core/optimizer/gelu_fusion.h" #include "core/optimizer/dropout_elimination.h" #include "orttraining/core/optimizer/bias_dropout_fusion.h" #include "orttraining/core/optimizer/gist_encode_decode.h" #include "orttraining/core/optimizer/nonzero_shape_setter.h" #include "orttraining/core/optimizer/megatron_transformer.h" #include "orttraining/core/optimizer/concat_replacement.h" +#include "orttraining/core/optimizer/localized_recompute.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/util/include/default_providers.h" #include "test/util/include/asserts.h" @@ -417,6 +420,33 @@ TEST_F(GraphTransformationTests, MegatronSelfAttentionPartitionRank1) { } } +TEST_F(GraphTransformationTests, BiasGeluRecomputeTest) { + auto model_uri = MODEL_FOLDER "fusion/bias_gelu_fusion_recompute.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Div"] == 0); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Erf"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.Gelu"] == 0); + ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 2); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "com.microsoft.BiasGelu") { + ASSERT_TRUE(node.InputDefs().size() == 2); + } + } +} + // We only tested on CUDA run. #if defined(USE_CUDA) TEST_F(GraphTransformationTests, MegatronMLPPartitionCorrectnessTest) {