diff --git a/onnxruntime/core/optimizer/fast_gelu_fusion.cc b/onnxruntime/core/optimizer/fast_gelu_fusion.cc index 57a4e5f86d..6c622e096f 100644 --- a/onnxruntime/core/optimizer/fast_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/fast_gelu_fusion.cc @@ -141,6 +141,25 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, } nodes_to_fuse.push_back(add1_node); + // check if pow node has Cast parent, expect Add has same Cast parent as well + const Node* p_cast1_node = graph_utils::FirstParentByType(pow1_node, "Cast"); + if (p_cast1_node != nullptr) { + Node& cast1_node = *graph.GetNode(p_cast1_node->Index()); + // this is fused Cast node, so expect 2 output edges + if (!CheckNode(graph, cast1_node, "Cast", {9, 13}, pow1_node.GetExecutionProviderType(), false) || + cast1_node.GetOutputEdgesCount() != 2){ + return matchResult; + } + const Node* p_pow_node = graph_utils::FirstChildByType(cast1_node, "Pow"); + if (p_pow_node == nullptr || p_pow_node->Index() != pow1_node.Index()) { + return matchResult; + } + const Node* p_add_node = graph_utils::FirstChildByType(cast1_node, "Add"); + if (p_add_node == nullptr || p_add_node->Index() != add1_node.Index()) { + return matchResult; + } + } + Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index()); input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]); if (!CheckNode(graph, mul2_node, "Mul", {7, 13}, pow1_node.GetExecutionProviderType(), true) || @@ -156,6 +175,22 @@ MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, return matchResult; } +/** +In case of ORTModule, there are extra Cast nodes exported for fp16. They should be fused into two nodes: + +x --> Cast --> FastGelu + +The first Cast should have been fused in CommonSubexpressionElimination transformer, thus it has 2 output edges. + ++--------------------------------------------> Mul ---> Cast ----+ +| | +| v +X --> Cast --> Pow --> Mul --> Add --> Mul --> Tanh --> Add --> Mul + | ^ + | | + +------------------------+ + +*/ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -169,12 +204,14 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); std::vector> nodes_to_fuse; + bool second_formula = false; MatchResult matchRet = CheckFirstFormula(graph, node, nodes_to_fuse); if (!matchRet.matched) { nodes_to_fuse.clear(); matchRet = CheckSecondFormula(graph, node, nodes_to_fuse); if (!matchRet.matched) continue; + second_formula = true; }; Node& tanh_node = *graph.GetNode(matchRet.tanh_input_node->OutputNodesBegin()->Index()); @@ -201,6 +238,30 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, input_index = optimizer_utils::IndexOfNodeInput(mul5_node, *add2_node.MutableOutputDefs()[0]); const Node* p_mul5_input_node = graph_utils::GetInputNode(mul5_node, (input_index + 1) % 2); if (p_mul5_input_node == nullptr) continue; + + // if this is second formula and if pow node has Cast parent, expect mul5_node has Cast parent as well + NodeArg* cast_input_arg = nullptr; + if (second_formula) { + const Node* p_cast1_node = graph_utils::FirstParentByType(node, "Cast"); + if (p_cast1_node != nullptr) { + // we've done the node check in second formula for pow node + Node& cast1_node = *graph.GetNode(p_cast1_node->Index()); + cast_input_arg = cast1_node.MutableInputDefs()[0]; + + const Node* p_cast3_node = graph_utils::FirstParentByType(mul5_node, "Cast"); + if (p_cast3_node == nullptr) continue; + + Node& cast3_node = *graph.GetNode(p_cast3_node->Index()); + if (!CheckNode(graph, cast3_node, "Cast", {9, 13}, node.GetExecutionProviderType(), true)) { + continue; + } + // overwrite and continue as usual + p_mul5_input_node = graph_utils::FirstParentByType(cast3_node, "Mul"); + nodes_to_fuse.push_back(cast3_node); + // keep cast1_node for reuse, its output edges will be adjusted in FinalizeNodeFusion() + } + } + Node& mul6_node = const_cast(*p_mul5_input_node); if (!CheckNode(graph, mul6_node, "Mul", {7, 13}, node.GetExecutionProviderType(), false)) { continue; @@ -214,8 +275,15 @@ Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, } } - if (input_index == -1 || mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != matchRet.gelu_without_bias_input_arg->Name()) - continue; + if (input_index == -1) continue; + // check same parent for both mul6 and pow, with or without cast + if (cast_input_arg != nullptr) { + if (mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != cast_input_arg->Name()) + continue; + } else { + if (mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != matchRet.gelu_without_bias_input_arg->Name()) + continue; + } std::vector gelu_input_defs{matchRet.gelu_without_bias_input_arg}; nodes_to_fuse.insert(nodes_to_fuse.end(), {tanh_node, add2_node, mul6_node, mul5_node}); diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 64cdc2ba2a..aaa9bcbe07 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -206,7 +206,7 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, if (p_cast_node != nullptr) { Node& cast_node = *graph.GetNode(p_cast_node->Index()); if (!graph_utils::IsSupportedOptypeVersionAndDomain(cast_node, "Cast", {9, 13}) || - cast_node.GetExecutionProviderType() != cast_node.GetExecutionProviderType() || + cast_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() || !optimizer_utils::CheckOutputEdges(graph, cast_node, 1)) { continue; } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 4f70f1a000..2568839cce 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -19,6 +19,7 @@ #include "core/optimizer/bias_softmax_fusion.h" #include "core/optimizer/computation_reduction.h" #include "core/optimizer/cast_elimination.h" +#include "core/optimizer/common_subexpression_elimination.h" #include "core/optimizer/concat_slice_elimination.h" #include "core/optimizer/constant_folding.h" #include "core/optimizer/conv_activation_fusion.h" @@ -2479,6 +2480,34 @@ TEST_F(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest2) { ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1); } +TEST_F(GraphTransformationTests, FastGeluFusionWithCastsTest3) { + auto model_uri = MODEL_FOLDER "fusion/fast_gelu3_with_casts.onnx"; + std::shared_ptr p_model; + auto load_ret = Model::Load(model_uri, p_model, nullptr, *logger_); + ASSERT_TRUE(load_ret.IsOK()); + Graph& graph = p_model->MainGraph(); + + // ORTModule for gpt2 model has two casts fused into one before FastGeluFusion + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Cast"] == 2); + + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + ASSERT_TRUE(ret.IsOK()); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Tanh"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["Cast"] == 1); + ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1); +} + + struct BiasSoftmaxFusionTester { std::shared_ptr p_model_; Status model_load_; diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.onnx b/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.onnx new file mode 100644 index 0000000000..46275f8b5e Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py b/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py new file mode 100644 index 0000000000..30cd332385 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py @@ -0,0 +1,92 @@ +import onnx +from onnx import helper +from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto +from onnx import numpy_helper +import numpy as np + +# Gelu formula: x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))) + +X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64]) +Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64]) + +pow_np_vals = np.asarray([3]).astype(np.float32).reshape(()) +pow_initializer = numpy_helper.from_array(pow_np_vals, "pow_init") + +a_weight_np_vals = np.asarray([0.044714998453855515]).astype(np.float32).reshape(()) +a_weight_initializer = numpy_helper.from_array(a_weight_np_vals, "mul1_init") + +b_weight_np_vals = np.asarray([0.7978845834732056]).astype(np.float32).reshape(()) +b_weight_initializer = numpy_helper.from_array(b_weight_np_vals, "mul2_init") + +c_weight_np_vals = np.asarray([0.5]).astype(np.float32).reshape(()) +c_weight_initializer = numpy_helper.from_array(c_weight_np_vals, "mul3_init") + +b_bias_np_vals = np.asarray([1.0]).astype(np.float32).reshape(()) +b_bias_initializer = numpy_helper.from_array(b_bias_np_vals, "add2_init") + +nodes = [] +gelu_input = "input" +leading_identity = helper.make_node('Identity', [gelu_input], ['identity_leading'], name="identity_leading") +gelu_input = "identity_leading" +nodes.append(leading_identity) + +mul_input_name = gelu_input + +cast1 = helper.make_node('Cast', [mul_input_name], ['cast1'], name='cast1', to=1) +nodes.append(cast1) + +pow1 = helper.make_node('Pow', ['cast1', pow_initializer.name], ['pow1'], name="pow1") +nodes.append(pow1) + +mul1 = helper.make_node('Mul', ['pow1', a_weight_initializer.name], ['mul1'], name="mul1") +nodes.append(mul1) + +cast2 = helper.make_node('Cast', [mul_input_name], ['cast2'], name='cast2', to=1) +nodes.append(cast2) + +add1 = helper.make_node('Add', ['mul1', 'cast2'], ['add1'], name="add1") +nodes.append(add1) + +mul2 = helper.make_node('Mul', ['add1', b_weight_initializer.name], ['mul2'], name="mul2") +nodes.append(mul2) + +tanh = helper.make_node('Tanh', ['mul2'], ['tanh'], name="tanh") +nodes.append(tanh) + +add2 = helper.make_node('Add', ['tanh', b_bias_initializer.name], ['add2'], name="add2") +nodes.append(add2) + +mul5 = helper.make_node('Mul', [mul_input_name, c_weight_initializer.name], ['mul5'], name="mul5") +nodes.append(mul5) + +cast3 = helper.make_node('Cast', ['mul5'], ['cast3'], name='cast3', to=1) +nodes.append(cast3) + +mul6 = helper.make_node('Mul', ['cast3', 'add2'], ['mul6'], name="mul6") +ending_identity = helper.make_node('Identity', ['mul6'], ['output'], name="ending_identity") +nodes.extend([mul6, ending_identity]) + +initializers = [] + +initializers.extend( + [pow_initializer, a_weight_initializer, b_weight_initializer, b_bias_initializer, c_weight_initializer]) +# Create the graph (GraphProto) +graph_def = helper.make_graph(nodes, 'test-model', [X], [Y], initializers) + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 13 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = "com.microsoft" + +opsets.append(msdomain) +kwargs = {} +kwargs["opset_imports"] = opsets + +model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) + +onnx.save(model_def, "fast_gelu3_with_casts.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py b/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py index b596422598..b814294626 100644 --- a/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py +++ b/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py @@ -7,7 +7,7 @@ from enum import Enum def GenerateModel(model_name): - nodes = [ # SimplifiedLayerNorm subgraph + nodes = [ # LayerNormWithCast2 subgraph helper.make_node("ReduceMean", ["A"], ["rd1_out"], "reduce", axes=[-1]), helper.make_node("Sub", ["A", "rd1_out"], ["sub1_out"], "sub"), helper.make_node("Cast", ["pow_in_2"], ["cast_out"], "cast", to=10), @@ -29,7 +29,7 @@ def GenerateModel(model_name): graph = helper.make_graph( nodes, - "SimplifiedLayerNorm", #name + "LayerNormWithCast2", #name [ # inputs helper.make_tensor_value_info('A', TensorProto.FLOAT16, [16, 32, 4]), ],