From 2bda3fd341ac0929653834d3565ba01d02342b7d Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 10 Nov 2022 13:03:30 +0800 Subject: [PATCH] Gather to Slice Fusion (#13599) This PR is to optimize the running for below code from Huggingface's XLNet model. ``` x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long)) ``` The code will be exported to Range->Gather, which can be fused to a Slice Op. Slice kernel is much faster than Gather, especially for backward run. The main reason is for Gather, the data in indices can be duplicated so that it needs sum during backward, but Slice node cannot have such case. Use Huggingface's XLNet model for profiling. - Before the fuse forward, ~753us ![image](https://user-images.githubusercontent.com/11661208/200758439-63f2f9b5-9610-4df8-98c8-a1ad4dc62f4e.png) backward, ~46101us ![image](https://user-images.githubusercontent.com/11661208/200758530-fe16a8ec-ea8f-4b79-b3ac-386b72ba1670.png) - After the fuse forward, ~627us ![image](https://user-images.githubusercontent.com/11661208/200758654-ab9a6068-c45d-40f4-9c71-3862a56732f8.png) backward, ~677us ![image](https://user-images.githubusercontent.com/11661208/200758833-aab1b8e1-1b5d-4e55-88cf-03c2a1d9d42b.png) --- ...er_to_split_fusion.cc => gather_fusion.cc} | 117 +++++++++++++++++- ...ther_to_split_fusion.h => gather_fusion.h} | 13 ++ .../core/optimizer/graph_transformer_utils.cc | 3 +- .../test/optimizer/graph_transform_test.cc | 86 ++++++++++++- .../core/optimizer/graph_transformer_utils.cc | 3 +- 5 files changed, 218 insertions(+), 4 deletions(-) rename onnxruntime/core/optimizer/{gather_to_split_fusion.cc => gather_fusion.cc} (56%) rename onnxruntime/core/optimizer/{gather_to_split_fusion.h => gather_fusion.h} (64%) diff --git a/onnxruntime/core/optimizer/gather_to_split_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc similarity index 56% rename from onnxruntime/core/optimizer/gather_to_split_fusion.cc rename to onnxruntime/core/optimizer/gather_fusion.cc index e0e5b1b2ef..272eb0c252 100644 --- a/onnxruntime/core/optimizer/gather_to_split_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/gather_to_split_fusion.h" +#include "core/optimizer/gather_fusion.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" @@ -175,4 +175,119 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le return Status::OK(); } + +/* +Fuse Range->Gather to Slice. Slice kernel is faster than Gather kernel in this case, +and SliceGrad is much faster than GatherGrad. +*/ +Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); + if (p_node == nullptr) continue; // we removed the node as part of an earlier fusion + Node& node = *p_node; + + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Range", {1, 11}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) { + continue; + } + + Node& gather_node = *graph.GetNode(node.OutputNodesBegin()->Index()); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(gather_node, "Gather", {1, 11, 13}) || + !graph_utils::IsSupportedProvider(gather_node, GetCompatibleExecutionProviders())) { + continue; + } + + // Range's output is Gather's input[1]. + if (node.MutableOutputDefs()[0] != gather_node.MutableInputDefs()[1]) { + continue; + } + + InlinedVector> nodes_to_fuse{node, gather_node}; + + auto& range_input_defs = node.MutableInputDefs(); + ORT_ENFORCE(range_input_defs.size() == 3); + // Range's inputs are scalar, need unsqueeze to 1-D tensors. + ONNX_NAMESPACE::TypeProto unsqueeze_output_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( + node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()); + unsqueeze_output_type.mutable_tensor_type()->set_elem_type(element_type); + unsqueeze_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); + + InlinedVector unsqueeze_outputs; + for (size_t i = 0; i < range_input_defs.size(); ++i) { + unsqueeze_outputs.emplace_back(&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("unsqueeze_output_" + std::to_string(i)), &unsqueeze_output_type)); + } + + // Unsqueeze before and after OpSet-13 have different schemas. + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + if (onnx_opset_version < 13) { + for (size_t i = 0; i < range_input_defs.size(); ++i) { + Node& unsqueeze_node = + graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze", + "Unsqueeze for Fused Gather nodes", {range_input_defs[i]}, {unsqueeze_outputs[i]}); + unsqueeze_node.AddAttribute("axes", std::vector{static_cast(0)}); + unsqueeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + } + } else { + ONNX_NAMESPACE::TensorProto unsqueeze_axes_initializer_proto; + unsqueeze_axes_initializer_proto.set_name(graph.GenerateNodeName("UnsqueezeAxesInitializer")); + unsqueeze_axes_initializer_proto.add_dims(static_cast(1)); + unsqueeze_axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + unsqueeze_axes_initializer_proto.add_int64_data(static_cast(0)); + NodeArg* unsqueeze_axes_arg = &graph_utils::AddInitializer(graph, unsqueeze_axes_initializer_proto); + + for (size_t i = 0; i < range_input_defs.size(); ++i) { + Node& unsqueeze_node = graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze", + "Unsqueeze for Fused Gather nodes", + {range_input_defs[i], unsqueeze_axes_arg}, {unsqueeze_outputs[i]}); + unsqueeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + } + } + + int64_t axis = 0; // Default value. + auto& attrs = gather_node.GetAttributes(); + if (attrs.find("axis") != attrs.end()) { + auto& axis_attr = attrs.at("axis"); + if (utils::HasInt(axis_attr)) axis = axis_attr.i(); + } + + ONNX_NAMESPACE::TensorProto slice_axes_initializer_proto; + slice_axes_initializer_proto.set_name(graph.GenerateNodeName("SliceAxesInitializer")); + slice_axes_initializer_proto.add_dims(static_cast(1)); + slice_axes_initializer_proto.set_data_type(element_type); + // Tind of Slice can only support int32 and int64. + if (element_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { + slice_axes_initializer_proto.add_int64_data(axis); + } else { + slice_axes_initializer_proto.add_int32_data(static_cast(axis)); + } + NodeArg* slice_axes_arg = &graph_utils::AddInitializer(graph, slice_axes_initializer_proto); + Node& slice_node = graph.AddNode(graph.GenerateNodeName("Slice"), "Slice", "Slice for Fused Gather nodes", + {gather_node.MutableInputDefs()[0], unsqueeze_outputs[0], unsqueeze_outputs[1], + slice_axes_arg, unsqueeze_outputs[2]}, + {gather_node.MutableOutputDefs()[0]}); + slice_node.SetExecutionProviderType(gather_node.GetExecutionProviderType()); + + for (Node& n : nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, n); + graph.RemoveNode(n.Index()); + } + + modified = true; + } + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gather_to_split_fusion.h b/onnxruntime/core/optimizer/gather_fusion.h similarity index 64% rename from onnxruntime/core/optimizer/gather_to_split_fusion.h rename to onnxruntime/core/optimizer/gather_fusion.h index 4860393b47..bdb8c1be82 100644 --- a/onnxruntime/core/optimizer/gather_to_split_fusion.h +++ b/onnxruntime/core/optimizer/gather_fusion.h @@ -23,4 +23,17 @@ class GatherToSplitFusion : public GraphTransformer { bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const; }; +/** +@Class GatherToGliceFusion + +Fuse Range->Gather to Slice node. +*/ +class GatherToSliceFusion : public GraphTransformer { + public: + GatherToSliceFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GatherToSliceFusion", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 58b64511d4..ec0e637bce 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -39,7 +39,7 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" -#include "core/optimizer/gather_to_split_fusion.h" +#include "core/optimizer/gather_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -268,6 +268,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 70a23dd622..fa4a1c2dec 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -39,7 +39,7 @@ #include "core/optimizer/embed_layer_norm_fusion.h" #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" -#include "core/optimizer/gather_to_split_fusion.h" +#include "core/optimizer/gather_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -6391,5 +6391,89 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { } } +TEST_F(GraphTransformationTests, GatherToSliceFusion) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + ASSERT_EQ(op_count_map["Range"], 1); + ASSERT_EQ(op_count_map["Gather"], 1); + }; + + // OpSet-12, Tind is int32. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{8, 8, 8, 8}}); + auto* range_input_1 = builder.MakeInitializer({}, {0}); + auto* range_input_2 = builder.MakeInitializer({}, {8}); + auto* range_input_3 = builder.MakeInitializer({}, {1}); + auto* range_output = builder.MakeIntermediate(); + auto* gather_output = builder.MakeOutput(); + + builder.AddNode("Range", {range_input_1, range_input_2, range_input_3}, {range_output}); + builder.AddNode("Gather", {data_arg, range_output}, {gather_output}) + .AddAttribute("axis", static_cast(2)); + }; + + auto post_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + ASSERT_EQ(op_count_map["Range"], 0); + ASSERT_EQ(op_count_map["Gather"], 0); + ASSERT_EQ(op_count_map["Slice"], 1); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Slice") { + const NodeArg& input_arg = *(node.InputDefs()[3]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + ASSERT_TRUE(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + ASSERT_TRUE(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32); + ASSERT_EQ(2, *(init_const.data())); + } + } + }; + + std::unique_ptr transformer = std::make_unique(); + TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker); + } + + // OpSet-14, Tind is int64. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{8, 8, 8, 8}}); + auto* range_input_1 = builder.MakeInitializer({}, {static_cast(0)}); + auto* range_input_2 = builder.MakeInitializer({}, {static_cast(8)}); + auto* range_input_3 = builder.MakeInitializer({}, {static_cast(1)}); + auto* range_output = builder.MakeIntermediate(); + auto* gather_output = builder.MakeOutput(); + + builder.AddNode("Range", {range_input_1, range_input_2, range_input_3}, {range_output}); + builder.AddNode("Gather", {data_arg, range_output}, {gather_output}) + .AddAttribute("axis", static_cast(2)); + }; + + auto post_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + ASSERT_EQ(op_count_map["Range"], 0); + ASSERT_EQ(op_count_map["Gather"], 0); + ASSERT_EQ(op_count_map["Slice"], 1); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Slice") { + const NodeArg& input_arg = *(node.InputDefs()[3]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + ASSERT_TRUE(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + ASSERT_TRUE(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + ASSERT_EQ(2, static_cast(*(init_const.data()))); + } + } + }; + + std::unique_ptr transformer = std::make_unique(); + TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker); + } +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 421c5c8663..9745a1b36d 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -23,7 +23,7 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" -#include "core/optimizer/gather_to_split_fusion.h" +#include "core/optimizer/gather_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -103,6 +103,7 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps)); #if defined(USE_CUDA) || defined(USE_ROCM) // We are supposed to use execution provider as indicator, but here we don't have access to the registered EP at this point