diff --git a/onnxruntime/core/optimizer/bias_gelu_fusion.cc b/onnxruntime/core/optimizer/bias_gelu_fusion.cc index 370e8b144b..5bdad45b9c 100644 --- a/onnxruntime/core/optimizer/bias_gelu_fusion.cc +++ b/onnxruntime/core/optimizer/bias_gelu_fusion.cc @@ -64,20 +64,28 @@ Status BiasGelu::ApplyImpl(Graph& graph, bool& modified, int graph_level, const } const Node& next_node = (*next_node_itr); - if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) || + if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "FastGelu", {1}, kMSDomain)) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { continue; } + bool is_fast_gelu = next_node.OpType().compare("FastGelu") == 0; + if (is_fast_gelu && next_node.InputDefs().size() > 1) { + continue; + } + if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { continue; } Node& add_node = node; Node& gelu_node = const_cast(next_node); + std::string op_type = "BiasGelu"; + if (is_fast_gelu) op_type = "FastGelu"; - Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName("BiasGelu"), - "BiasGelu", + Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type), + op_type, "fused Add and Gelu", gelu_input, {}, diff --git a/onnxruntime/core/optimizer/fast_gelu_fusion.cc b/onnxruntime/core/optimizer/fast_gelu_fusion.cc new file mode 100644 index 0000000000..d0394a0da5 --- /dev/null +++ b/onnxruntime/core/optimizer/fast_gelu_fusion.cc @@ -0,0 +1,244 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/initializer.h" +#include "core/optimizer/fast_gelu_fusion.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph_utils.h" +#include "float.h" +#include + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; +namespace onnxruntime { + +// FastGelu supports limited data types. +static std::vector supported_data_types{"tensor(float16)", "tensor(float)"}; + +static bool CheckNode(const Node& node, const std::string& op_name, int32_t opset_version, ProviderType provider, + bool require_single_output=false){ + return graph_utils::IsSupportedOptypeVersionAndDomain(node, op_name, {opset_version}) && + node.GetExecutionProviderType() == provider && + optimizer_utils::IsSupportedDataType(node, supported_data_types) && + (!require_single_output || node.GetOutputEdgesCount() == 1); +} + +MatchResult FastGeluFusion::CheckFirstFormula(Graph& graph, Node& mul1_node, + std::vector>& nodes_to_fuse) const { + MatchResult matchResult{false, nullptr, nullptr}; + if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul1_node, "Mul", {7}) || + !graph_utils::IsSupportedProvider(mul1_node, GetCompatibleExecutionProviders()) || + mul1_node.GetOutputEdgesCount() != 1 || + !optimizer_utils::IsSupportedDataType(mul1_node, supported_data_types)) { + return matchResult; + } + + int32_t input_index = -1; + const float mul_val = 0.044715f; + for (auto i = 0; i < 2; i++) { + if (optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[i]), mul_val, true)){ + input_index = i; + break; + } + } + + if (input_index == -1) return matchResult; + + NodeArg* gelu_without_bias_input_arg = mul1_node.MutableInputDefs()[(input_index + 1) % 2]; + nodes_to_fuse.push_back(mul1_node); + + + Node& mul2_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index()); + input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *mul1_node.MutableOutputDefs()[0]); + if (!CheckNode(mul2_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true) || + mul2_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name()) { + return matchResult;; + } + nodes_to_fuse.push_back(mul2_node); + + + Node& add1_node = *graph.GetNode(mul2_node.OutputNodesBegin()->Index()); + input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul2_node.MutableOutputDefs()[0]); + if (!CheckNode(add1_node, "Add", 7, mul1_node.GetExecutionProviderType(), true) || + !optimizer_utils::IsInitializerWithExpectedValue(graph, *(add1_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) { + return matchResult; + } + nodes_to_fuse.push_back(add1_node); + + + Node& mul3_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index()); + if (!CheckNode(mul3_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) { + return matchResult; + } + nodes_to_fuse.push_back(mul3_node); + + input_index = optimizer_utils::IndexOfNodeInput(mul3_node, *add1_node.MutableOutputDefs()[0]); + const Node* p_mul3_input_node = graph_utils::GetInputNode(mul3_node, (input_index + 1) % 2); + if (p_mul3_input_node == nullptr) return matchResult; + Node& mul4_node = const_cast(*p_mul3_input_node); + if (!CheckNode(mul4_node, "Mul", 7, mul1_node.GetExecutionProviderType(), true)) { + return matchResult; + } + + input_index = -1; + const float mul4_val = 0.7978845834732056f; + for (auto i = 0; i < 2; i++) { + if (optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul4_node.InputDefs()[i]), mul4_val, true)){ + input_index = i; + break; + } + } + + if (input_index == -1 || mul4_node.InputDefs()[(input_index + 1) % 2]->Name() != gelu_without_bias_input_arg->Name()) + return matchResult; + nodes_to_fuse.push_back(mul4_node); + + matchResult.matched = true; + matchResult.gelu_without_bias_input_arg = gelu_without_bias_input_arg; + matchResult.tanh_input_node = &mul3_node; + return matchResult; +} + +MatchResult FastGeluFusion::CheckSecondFormula(Graph& graph, Node& pow1_node, + std::vector>& nodes_to_fuse) const { + MatchResult matchResult{false, nullptr, nullptr}; + if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow1_node, "Pow", {7}) || + !graph_utils::IsSupportedProvider(pow1_node, GetCompatibleExecutionProviders()) || + pow1_node.GetOutputEdgesCount() != 1 || + !optimizer_utils::IsSupportedDataType(pow1_node, supported_data_types)) { + return matchResult; + } + + if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(pow1_node.InputDefs()[1]), 3.0f, true)){ + return matchResult; + } + + NodeArg* pow_input_arg = pow1_node.MutableInputDefs()[0]; + nodes_to_fuse.push_back(pow1_node); + + Node& mul1_node = *graph.GetNode(pow1_node.OutputNodesBegin()->Index()); + auto input_index = optimizer_utils::IndexOfNodeInput(mul1_node, *pow1_node.MutableOutputDefs()[0]); + if (!CheckNode(mul1_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) || + !optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul1_node.InputDefs()[(input_index + 1) % 2]), + 0.044714998453855515f, true)) { + return matchResult; + } + nodes_to_fuse.push_back(mul1_node); + + + Node& add1_node = *graph.GetNode(mul1_node.OutputNodesBegin()->Index()); + input_index = optimizer_utils::IndexOfNodeInput(add1_node, *mul1_node.MutableOutputDefs()[0]); + if (!CheckNode(add1_node, "Add", 7, pow1_node.GetExecutionProviderType(), true) || + add1_node.MutableInputDefs()[(input_index + 1) % 2]->Name() != pow_input_arg->Name()) { + return matchResult; + } + nodes_to_fuse.push_back(add1_node); + + + Node& mul2_node = *graph.GetNode(add1_node.OutputNodesBegin()->Index()); + input_index = optimizer_utils::IndexOfNodeInput(mul2_node, *add1_node.MutableOutputDefs()[0]); + if (!CheckNode(mul2_node, "Mul", 7, pow1_node.GetExecutionProviderType(), true) || + !optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul2_node.InputDefs()[(input_index + 1) % 2]), + 0.7978845834732056f, true)) { + return matchResult; + } + nodes_to_fuse.push_back(mul2_node); + + matchResult.matched = true; + matchResult.gelu_without_bias_input_arg = pow_input_arg; + matchResult.tanh_input_node = &mul2_node; + return matchResult; +} + +Status FastGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); + if (p_node == nullptr) + continue; + + Node& node = *p_node; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + std::vector> nodes_to_fuse; + MatchResult matchRet = CheckFirstFormula(graph, node, nodes_to_fuse); + if (!matchRet.matched) { + nodes_to_fuse.clear(); + matchRet = CheckSecondFormula(graph, node, nodes_to_fuse); + + if(!matchRet.matched) continue; + }; + + Node& tanh_node = *graph.GetNode(matchRet.tanh_input_node->OutputNodesBegin()->Index()); + if (!CheckNode(tanh_node, "Tanh", 6, node.GetExecutionProviderType(), true)) { + continue; + } + + + Node& add2_node = *graph.GetNode(tanh_node.OutputNodesBegin()->Index()); + if (!CheckNode(add2_node, "Add", 7, node.GetExecutionProviderType(), true)) { + continue; + } + + auto input_index = optimizer_utils::IndexOfNodeInput(add2_node, *tanh_node.MutableOutputDefs()[0]); + if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(add2_node.InputDefs()[(input_index + 1) % 2]), 1.0f, true)) { + continue; + } + + Node& mul5_node = *graph.GetNode(add2_node.OutputNodesBegin()->Index()); + // This is the output of the Gelu subgraph, we don't need check it has single edge. + if (!CheckNode(mul5_node, "Mul", 7, node.GetExecutionProviderType(), false)) { + continue; + } + + // ingnore the transformer if Gelu's output is the graph's output. + if (!graph.GetNodeOutputsInGraphOutputs(mul5_node).empty()) { + continue; + } + + input_index = optimizer_utils::IndexOfNodeInput(mul5_node, *add2_node.MutableOutputDefs()[0]); + const Node* p_mul5_input_node = graph_utils::GetInputNode(mul5_node, (input_index + 1) % 2); + if (p_mul5_input_node == nullptr) continue; + Node& mul6_node = const_cast(*p_mul5_input_node); + if (!CheckNode(mul6_node, "Mul", 7, node.GetExecutionProviderType(), false)) { + continue; + } + + input_index = -1; + for (auto i = 0; i < 2; i++) { + if (optimizer_utils::IsInitializerWithExpectedValue(graph, *(mul6_node.InputDefs()[i]), 0.5f, true)){ + input_index = i; + break; + } + } + + if (input_index == -1 || mul6_node.InputDefs()[(input_index + 1) % 2]->Name() != matchRet.gelu_without_bias_input_arg->Name()) + continue; + + std::vector gelu_input_defs{matchRet.gelu_without_bias_input_arg}; + nodes_to_fuse.insert(nodes_to_fuse.end(), {tanh_node, add2_node, mul6_node, mul5_node}); + + auto type_info = *node.MutableOutputDefs()[0]->TypeAsProto(); + auto& shape_output = graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("fast_gelu_output"), &type_info); + Node& fast_gelu_node = graph.AddNode(graph.GenerateNodeName("GPT2Gelu"), + "FastGelu", + "fused GPT2Gelu subgraphs ", + gelu_input_defs, + {&shape_output}, {}, kMSDomain); + + // assign provider to this new node, provider should be same as the provider for old node. + fast_gelu_node.SetExecutionProviderType(node.GetExecutionProviderType()); + + // move input edges to node (first in list) across to the fast_gelu_node. + // move output definitions and output edges from mul5_node (last in list) to fast_gelu_node. + // remove all nodes. + graph_utils::FinalizeNodeFusion(graph, nodes_to_fuse, fast_gelu_node); + + modified = true; + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/fast_gelu_fusion.h b/onnxruntime/core/optimizer/fast_gelu_fusion.h new file mode 100644 index 0000000000..e2d70c18a1 --- /dev/null +++ b/onnxruntime/core/optimizer/fast_gelu_fusion.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +struct MatchResult { + public: + bool matched; + NodeArg* gelu_without_bias_input_arg; // The Gelu input arg if not considering bias node. + Node* tanh_input_node; +}; + +/** +@Class FastGeluFusion + +Rewrite graph fusing Gelu activation subgraph to a single Gelu node. + +The formula corresponding to Gelu activation subgraph: +x * 0.5 * (1.0 + tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) or +x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))), where x is the input. + +*/ +class FastGeluFusion : public GraphTransformer { + public: + FastGeluFusion(const std::unordered_set& compatible_execution_providers = {}) noexcept + : GraphTransformer("FastGeluFusion", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + MatchResult CheckFirstFormula(Graph& graph, Node& node, std::vector>& nodes_to_fuse) const; + + MatchResult CheckSecondFormula(Graph& graph, Node& nodes, std::vector>& nodes_to_fuse) const; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 31e57907cc..fe972c1ecf 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -21,6 +21,7 @@ #include "core/optimizer/bias_gelu_fusion.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gelu_approximation.h" +#include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/layer_norm_fusion.h" #include "core/optimizer/skip_layer_norm_fusion.h" #include "core/optimizer/embed_layer_norm_fusion.h" @@ -135,6 +136,7 @@ std::vector> GenerateTransformers(TransformerL std::unordered_set cuda_execution_providers = {onnxruntime::kCudaExecutionProvider}; transformers.emplace_back(onnxruntime::make_unique(cuda_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(cuda_execution_providers)); #endif } break; diff --git a/onnxruntime/core/optimizer/utils.cc b/onnxruntime/core/optimizer/utils.cc index c58bac0498..9af9fa4269 100644 --- a/onnxruntime/core/optimizer/utils.cc +++ b/onnxruntime/core/optimizer/utils.cc @@ -36,6 +36,8 @@ bool IsInitializerWithExpectedValue(const Graph& graph, const NodeArg& input_arg return false; } + const float atol = 1e-8f; + const float rtol = 1e-5f; const ONNX_NAMESPACE::TensorProto* tensor_proto = nullptr; if (is_constant) { tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); @@ -51,20 +53,28 @@ bool IsInitializerWithExpectedValue(const Graph& graph, const NodeArg& input_arg const auto data_type = tensor_proto->data_type(); if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { const float* val = init_const->data(); - float diff = std::abs(val[0] - static_cast(expected_value)); - if (diff > FLT_EPSILON) { + if (std::isnan(val[0]) || std::isinf(val[0])) return false; + + float diff = std::abs(val[0] - expected_value); + if (diff > (atol + rtol * std::abs(expected_value))) { return false; } } else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) { const double* val = init_const->data(); - double diff = std::abs(val[0] - static_cast(expected_value)); - if (diff > DBL_EPSILON) { + if (std::isnan(val[0]) || std::isinf(val[0])) return false; + + const double expected_val = static_cast(expected_value); + double diff = std::abs(val[0] - expected_val); + if (diff > (atol + rtol * std::abs(expected_value))) { return false; } } else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { const MLFloat16* val = init_const->data(); - float diff = std::abs(math::halfToFloat(val[0].val) - math::halfToFloat(math::floatToHalf(expected_value))); - if (diff > FLT_EPSILON) { + const float flt_val = math::halfToFloat(val[0].val); + if (std::isnan(flt_val) || std::isinf(flt_val)) return false; + const float expected_val = math::halfToFloat(math::floatToHalf(expected_value)); + float diff = std::abs(flt_val - expected_val); + if (diff > (atol + rtol * std::abs(expected_value))) { return false; } } else { @@ -176,5 +186,27 @@ bool IsShapeKnownOnAllDims(const NodeArg& node_arg, int expected_dim_size) { return true; } +int32_t IndexOfNodeInput(const Node& node, const NodeArg& node_arg) { + int32_t index = 0; + for (auto& input_arg : node.InputDefs()) { + if (input_arg->Name().compare(node_arg.Name()) == 0) { + return index; + } + index++; + } + + return -1; +} + +bool IsSupportedDataType(const Node& node, const std::vector& supported_data_types) { + for (const auto& input_arg : node.InputDefs()) { + if (std::find(supported_data_types.begin(), supported_data_types.end(), + *(input_arg->Type())) == supported_data_types.end()) { + return false; + } + } + return true; +} + } // namespace optimizer_utils } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/utils.h b/onnxruntime/core/optimizer/utils.h index 0fd434993f..89d56a81f0 100644 --- a/onnxruntime/core/optimizer/utils.h +++ b/onnxruntime/core/optimizer/utils.h @@ -52,5 +52,15 @@ bool ValidateShape(const NodeArg& node_arg, const std::initializer_list */ bool IsShapeKnownOnAllDims(const NodeArg& node_arg, int expected_dim_size); +/** Get the index of node_arg among the node's all inputs. +@remarks -1 when node_arg is not in node's inputs.. +*/ +int32_t IndexOfNodeInput(const Node& node, const NodeArg& node_arg); + +/** Check whether node's input data types are in supported data type list. +@param supported_data_types specify the supported data types. +*/ +bool IsSupportedDataType(const Node& node, const std::vector& supported_data_types); + } // namespace optimizer_utils } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index f4bc9ce8f9..242f18dc52 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -33,6 +33,7 @@ #include "core/optimizer/unsqueeze_elimination.h" #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/attention_fusion.h" +#include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/utils.h" #include "core/platform/env.h" #include "core/util/math.h" @@ -1201,6 +1202,163 @@ TEST(GraphTransformationTests, GeluApproximation_Gelu_Add_MatMul) { EXPECT_EQ(op_to_count["FastGelu"], 1); } +TEST(GraphTransformationTests, FastGeluFusionTest) { + auto model_uri = MODEL_FOLDER "fusion/fast_gelu.onnx"; + std::shared_ptr p_model; + auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(load_ret.IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Identity"] == 2); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Tanh"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["FastGelu"] == 1); +} + +TEST(GraphTransformationTests, FastGeluUseGraphInputFusionTest) { + auto model_uri = MODEL_FOLDER "fusion/fast_gelu_use_graph_input.onnx"; + std::shared_ptr p_model; + auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(load_ret.IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Tanh"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["FastGelu"] == 1); +} + +TEST(GraphTransformationTests, FastGeluWithBiasFusionTest) { + auto model_uri = MODEL_FOLDER "fusion/fast_gelu_with_bias.onnx"; + std::shared_ptr p_model; + auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(load_ret.IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Tanh"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["FastGelu"] == 1); +} + +TEST(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest) { + auto model_uri = MODEL_FOLDER "fusion/fast_gelu_with_bias_use_graph_input.onnx"; + std::shared_ptr p_model; + auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(load_ret.IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Tanh"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["FastGelu"] == 1); +} + +TEST(GraphTransformationTests, FastGeluFusionTest2) { + auto model_uri = MODEL_FOLDER "fusion/fast_gelu2.onnx"; + std::shared_ptr p_model; + auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(load_ret.IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Tanh"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["FastGelu"] == 1); +} + +TEST(GraphTransformationTests, FastGeluUseGraphInputFusionTest2) { + auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_use_graph_input.onnx"; + std::shared_ptr p_model; + auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(load_ret.IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Tanh"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["FastGelu"] == 1); +} + +TEST(GraphTransformationTests, FastGeluWithBiasFusionTest2) { + auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_with_bias.onnx"; + std::shared_ptr p_model; + auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(load_ret.IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Tanh"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["FastGelu"] == 1); +} + +TEST(GraphTransformationTests, FastGeluWithBiasUseGraphInputFusionTest2) { + auto model_uri = MODEL_FOLDER "fusion/fast_gelu2_with_bias_use_graph_input.onnx"; + std::shared_ptr p_model; + auto load_ret = Model::Load(model_uri, p_model, nullptr, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(load_ret.IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, DefaultLoggingManager().DefaultLogger()); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Add"] == 0); + ASSERT_TRUE(op_to_count["Tanh"] == 0); + ASSERT_TRUE(op_to_count["Mul"] == 0); + ASSERT_TRUE(op_to_count["FastGelu"] == 1); +} + TEST(GraphTransformationTests, LayerNormFusionTest) { auto model_uri = MODEL_FOLDER "fusion/layer_norm.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu.onnx b/onnxruntime/test/testdata/transform/fusion/fast_gelu.onnx new file mode 100644 index 0000000000..d94dede7e7 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fast_gelu.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu.py b/onnxruntime/test/testdata/transform/fusion/fast_gelu.py new file mode 100644 index 0000000000..bfd5c785f0 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/fast_gelu.py @@ -0,0 +1,172 @@ +import onnx +from onnx import helper +from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto +from onnx import numpy_helper +import numpy as np + +# Gelu formula: x * 0.5 * (1.0 + tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) + +has_bias = True # change it to True to generate fast_gelu_with_bias.onnx +gelu_use_graph_input = True # change it to False to let Gelu don't have graph inputs as inputs. + +X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64]) +Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64]) + +bias_np_vals = (0.01 * np.arange(64)).astype(np.float32).reshape((64)) +bias_initializer = numpy_helper.from_array(bias_np_vals, "input_bias") + +a_weight_np_vals = np.asarray([0.044714998453855515]).astype(np.float32).reshape(()) +a_weight_initializer = numpy_helper.from_array(a_weight_np_vals, "mul1_init") + +b_weight_np_vals = np.asarray([0.7978845834732056]).astype(np.float32).reshape(()) +b_weight_initializer = numpy_helper.from_array(b_weight_np_vals, "mul2_init") + +c_weight_np_vals = np.asarray([0.5]).astype(np.float32).reshape(()) +c_weight_initializer = numpy_helper.from_array(c_weight_np_vals, "mul3_init") + +a_bias_np_vals = np.asarray([1.0]).astype(np.float32).reshape(()) +a_bias_initializer = numpy_helper.from_array(a_bias_np_vals, "add1_init") + +b_bias_np_vals = np.asarray([1.0]).astype(np.float32).reshape(()) +b_bias_initializer = numpy_helper.from_array(b_bias_np_vals, "add2_init") + +nodes = [] +gelu_input = "input" +if not gelu_use_graph_input: + leading_identity = helper.make_node( + 'Identity', + [gelu_input], + ['identity_leading'], + name="identity_leading" + ) + gelu_input = "identity_leading" + nodes.append(leading_identity) + +mul_input_name = gelu_input +if has_bias: + add0 = helper.make_node( + 'Add', + [gelu_input, bias_initializer.name], + ['add0'], + name="add0" + ) + mul_input_name = "add0" + nodes.append(add0) + + +mul1 = helper.make_node( + 'Mul', + [mul_input_name, a_weight_initializer.name], + ['mul1'], + name="mul1" +) +nodes.append(mul1) + +mul2 = helper.make_node( + 'Mul', + [mul_input_name, 'mul1'], + ['mul2'], + name="mul2" +) +nodes.append(mul2) + +add1 = helper.make_node( + 'Add', + ['mul2', a_bias_initializer.name], + ['add1'], + name="add1" +) +nodes.append(add1) + +mul3 = helper.make_node( + 'Mul', + [mul_input_name, b_weight_initializer.name], + ['mul3'], + name="mul3" +) +nodes.append(mul3) + +mul4 = helper.make_node( + 'Mul', + ['mul3', 'add1'], + ['mul4'], + name="mul4" +) +nodes.append(mul4) + +tanh = helper.make_node( + 'Tanh', + ['mul4'], + ['tanh'], + name="tanh" +) +nodes.append(tanh) + +add2 = helper.make_node( + 'Add', + ['tanh', b_bias_initializer.name], + ['add2'], + name="add2" +) +nodes.append(add2) + +mul5 = helper.make_node( + 'Mul', + [mul_input_name, c_weight_initializer.name], + ['mul5'], + name="mul5" +) +nodes.append(mul5) + +mul6 = helper.make_node( + 'Mul', + ['mul5', 'add2'], + ['mul6'], + name="mul6" +) +ending_identity = helper.make_node( + 'Identity', + ['mul6'], + ['output'], + name="identity_ending" +) +nodes.extend([mul6, ending_identity]) + +initializers = [] +if has_bias: + initializers = [bias_initializer] + +initializers.extend([a_weight_initializer, a_bias_initializer, b_weight_initializer, b_bias_initializer, c_weight_initializer]) +# Create the graph (GraphProto) +graph_def = helper.make_graph( + nodes, + 'test-model', + [X], + [Y], + initializers +) + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 10 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = "com.microsoft" + +opsets.append(msdomain) +kwargs={} +kwargs["opset_imports"] = opsets + +model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) + +file_name = "fast_gelu" +if has_bias: + file_name += "_with_bias" + +if gelu_use_graph_input: + file_name += "_use_graph_input" +onnx.save(model_def, file_name + ".onnx") + diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu2.onnx b/onnxruntime/test/testdata/transform/fusion/fast_gelu2.onnx new file mode 100644 index 0000000000..2239de14bc Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fast_gelu2.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu2.py b/onnxruntime/test/testdata/transform/fusion/fast_gelu2.py new file mode 100644 index 0000000000..648336eeb2 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/fast_gelu2.py @@ -0,0 +1,163 @@ +import onnx +from onnx import helper +from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto +from onnx import numpy_helper +import numpy as np + +# Gelu formula: x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))) +has_bias = False # change it to True to generate fast_gelu_openai_with_bias.onnx +gelu_use_graph_input = True # change it to False to let Gelu don't have graph inputs/outputs as inputs/outputs. + +X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64]) +Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64]) + +bias_np_vals = (0.01 * np.arange(64)).astype(np.float32).reshape((64)) +bias_initializer = numpy_helper.from_array(bias_np_vals, "input_bias") + +pow_np_vals = np.asarray([3]).astype(np.float32).reshape(()) +pow_initializer = numpy_helper.from_array(pow_np_vals, "pow_init") + +a_weight_np_vals = np.asarray([0.044714998453855515]).astype(np.float32).reshape(()) +a_weight_initializer = numpy_helper.from_array(a_weight_np_vals, "mul1_init") + +b_weight_np_vals = np.asarray([0.7978845834732056]).astype(np.float32).reshape(()) +b_weight_initializer = numpy_helper.from_array(b_weight_np_vals, "mul2_init") + +c_weight_np_vals = np.asarray([0.5]).astype(np.float32).reshape(()) +c_weight_initializer = numpy_helper.from_array(c_weight_np_vals, "mul3_init") + +b_bias_np_vals = np.asarray([1.0]).astype(np.float32).reshape(()) +b_bias_initializer = numpy_helper.from_array(b_bias_np_vals, "add2_init") + +nodes = [] +gelu_input = "input" +if not gelu_use_graph_input: + leading_identity = helper.make_node( + 'Identity', + [gelu_input], + ['identity_leading'], + name="identity_leading" + ) + gelu_input = "identity_leading" + nodes.append(leading_identity) + +mul_input_name = gelu_input +if has_bias: + add0 = helper.make_node( + 'Add', + [gelu_input, bias_initializer.name], + ['add0'], + name="add0" + ) + mul_input_name = "add0" + nodes.append(add0) + + +pow1 = helper.make_node( + 'Pow', + [mul_input_name, pow_initializer.name], + ['pow1'], + name="pow1" +) +nodes.append(pow1) + +mul1 = helper.make_node( + 'Mul', + ['pow1', a_weight_initializer.name], + ['mul1'], + name="mul1" +) +nodes.append(mul1) + +add1 = helper.make_node( + 'Add', + [mul_input_name, "mul1"], + ['add1'], + name="add1" +) +nodes.append(add1) + +mul2 = helper.make_node( + 'Mul', + ['add1', b_weight_initializer.name], + ['mul2'], + name="mul2" +) +nodes.append(mul2) + +tanh = helper.make_node( + 'Tanh', + ['mul2'], + ['tanh'], + name="tanh" +) +nodes.append(tanh) + +add2 = helper.make_node( + 'Add', + ['tanh', b_bias_initializer.name], + ['add2'], + name="add2" +) +nodes.append(add2) + +mul5 = helper.make_node( + 'Mul', + [mul_input_name, c_weight_initializer.name], + ['mul5'], + name="mul5" +) +nodes.append(mul5) + +mul6 = helper.make_node( + 'Mul', + ['mul5', 'add2'], + ['mul6'], + name="mul6" +) +ending_identity = helper.make_node( + 'Identity', + ['mul6'], + ['output'], + name="ending_identity" +) +nodes.extend([mul6, ending_identity]) + +initializers = [] +if has_bias: + initializers = [bias_initializer] + +initializers.extend([pow_initializer, a_weight_initializer, b_weight_initializer, b_bias_initializer, c_weight_initializer]) +# Create the graph (GraphProto) +graph_def = helper.make_graph( + nodes, + 'test-model', + [X], + [Y], + initializers +) + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 10 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +msdomain = OperatorSetIdProto() +msdomain.version = 1 +msdomain.domain = "com.microsoft" + +opsets.append(msdomain) +kwargs={} +kwargs["opset_imports"] = opsets + +model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) + +file_name = "fast_gelu2" +if has_bias: + file_name += "_with_bias" + +if gelu_use_graph_input: + file_name += "_use_graph_input" +onnx.save(model_def, file_name + ".onnx") + diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu2_use_graph_input.onnx b/onnxruntime/test/testdata/transform/fusion/fast_gelu2_use_graph_input.onnx new file mode 100644 index 0000000000..e609f74134 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fast_gelu2_use_graph_input.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu2_with_bias.onnx b/onnxruntime/test/testdata/transform/fusion/fast_gelu2_with_bias.onnx new file mode 100644 index 0000000000..332435ebc3 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fast_gelu2_with_bias.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu2_with_bias_use_graph_input.onnx b/onnxruntime/test/testdata/transform/fusion/fast_gelu2_with_bias_use_graph_input.onnx new file mode 100644 index 0000000000..a83905e34a Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fast_gelu2_with_bias_use_graph_input.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu_use_graph_input.onnx b/onnxruntime/test/testdata/transform/fusion/fast_gelu_use_graph_input.onnx new file mode 100644 index 0000000000..51b018fb24 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fast_gelu_use_graph_input.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu_with_bias.onnx b/onnxruntime/test/testdata/transform/fusion/fast_gelu_with_bias.onnx new file mode 100644 index 0000000000..a7e212ac95 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fast_gelu_with_bias.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu_with_bias_use_graph_input.onnx b/onnxruntime/test/testdata/transform/fusion/fast_gelu_with_bias_use_graph_input.onnx new file mode 100644 index 0000000000..534b084559 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fast_gelu_with_bias_use_graph_input.onnx differ