diff --git a/onnxruntime/core/optimizer/gather_to_split_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc similarity index 56% rename from onnxruntime/core/optimizer/gather_to_split_fusion.cc rename to onnxruntime/core/optimizer/gather_fusion.cc index e0e5b1b2ef..272eb0c252 100644 --- a/onnxruntime/core/optimizer/gather_to_split_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/gather_to_split_fusion.h" +#include "core/optimizer/gather_fusion.h" #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" @@ -175,4 +175,119 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le return Status::OK(); } + +/* +Fuse Range->Gather to Slice. Slice kernel is faster than Gather kernel in this case, +and SliceGrad is much faster than GatherGrad. +*/ +Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); + if (p_node == nullptr) continue; // we removed the node as part of an earlier fusion + Node& node = *p_node; + + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Range", {1, 11}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) { + continue; + } + + Node& gather_node = *graph.GetNode(node.OutputNodesBegin()->Index()); + if (!graph_utils::IsSupportedOptypeVersionAndDomain(gather_node, "Gather", {1, 11, 13}) || + !graph_utils::IsSupportedProvider(gather_node, GetCompatibleExecutionProviders())) { + continue; + } + + // Range's output is Gather's input[1]. 
+ if (node.MutableOutputDefs()[0] != gather_node.MutableInputDefs()[1]) { + continue; + } + + InlinedVector> nodes_to_fuse{node, gather_node}; + + auto& range_input_defs = node.MutableInputDefs(); + ORT_ENFORCE(range_input_defs.size() == 3); + // Range's inputs are scalar, need unsqueeze to 1-D tensors. + ONNX_NAMESPACE::TypeProto unsqueeze_output_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( + node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type()); + unsqueeze_output_type.mutable_tensor_type()->set_elem_type(element_type); + unsqueeze_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); + + InlinedVector unsqueeze_outputs; + for (size_t i = 0; i < range_input_defs.size(); ++i) { + unsqueeze_outputs.emplace_back(&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("unsqueeze_output_" + std::to_string(i)), &unsqueeze_output_type)); + } + + // Unsqueeze before and after OpSet-13 have different schemas. + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + if (onnx_opset_version < 13) { + for (size_t i = 0; i < range_input_defs.size(); ++i) { + Node& unsqueeze_node = + graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze", + "Unsqueeze for Fused Gather nodes", {range_input_defs[i]}, {unsqueeze_outputs[i]}); + unsqueeze_node.AddAttribute("axes", std::vector{static_cast(0)}); + unsqueeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + } + } else { + ONNX_NAMESPACE::TensorProto unsqueeze_axes_initializer_proto; + unsqueeze_axes_initializer_proto.set_name(graph.GenerateNodeName("UnsqueezeAxesInitializer")); + unsqueeze_axes_initializer_proto.add_dims(static_cast(1)); + unsqueeze_axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + 
unsqueeze_axes_initializer_proto.add_int64_data(static_cast(0)); + NodeArg* unsqueeze_axes_arg = &graph_utils::AddInitializer(graph, unsqueeze_axes_initializer_proto); + + for (size_t i = 0; i < range_input_defs.size(); ++i) { + Node& unsqueeze_node = graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze", + "Unsqueeze for Fused Gather nodes", + {range_input_defs[i], unsqueeze_axes_arg}, {unsqueeze_outputs[i]}); + unsqueeze_node.SetExecutionProviderType(node.GetExecutionProviderType()); + } + } + + int64_t axis = 0; // Default value. + auto& attrs = gather_node.GetAttributes(); + if (attrs.find("axis") != attrs.end()) { + auto& axis_attr = attrs.at("axis"); + if (utils::HasInt(axis_attr)) axis = axis_attr.i(); + } + + ONNX_NAMESPACE::TensorProto slice_axes_initializer_proto; + slice_axes_initializer_proto.set_name(graph.GenerateNodeName("SliceAxesInitializer")); + slice_axes_initializer_proto.add_dims(static_cast(1)); + slice_axes_initializer_proto.set_data_type(element_type); + // Tind of Slice can only support int32 and int64. 
+ if (element_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) { + slice_axes_initializer_proto.add_int64_data(axis); + } else { + slice_axes_initializer_proto.add_int32_data(static_cast(axis)); + } + NodeArg* slice_axes_arg = &graph_utils::AddInitializer(graph, slice_axes_initializer_proto); + Node& slice_node = graph.AddNode(graph.GenerateNodeName("Slice"), "Slice", "Slice for Fused Gather nodes", + {gather_node.MutableInputDefs()[0], unsqueeze_outputs[0], unsqueeze_outputs[1], + slice_axes_arg, unsqueeze_outputs[2]}, + {gather_node.MutableOutputDefs()[0]}); + slice_node.SetExecutionProviderType(gather_node.GetExecutionProviderType()); + + for (Node& n : nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, n); + graph.RemoveNode(n.Index()); + } + + modified = true; + } + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gather_to_split_fusion.h b/onnxruntime/core/optimizer/gather_fusion.h similarity index 64% rename from onnxruntime/core/optimizer/gather_to_split_fusion.h rename to onnxruntime/core/optimizer/gather_fusion.h index 4860393b47..bdb8c1be82 100644 --- a/onnxruntime/core/optimizer/gather_to_split_fusion.h +++ b/onnxruntime/core/optimizer/gather_fusion.h @@ -23,4 +23,17 @@ class GatherToSplitFusion : public GraphTransformer { bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const; }; +/** +@Class GatherToSliceFusion + +Fuse Range->Gather to Slice node. 
+*/ +class GatherToSliceFusion : public GraphTransformer { + public: + GatherToSliceFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GatherToSliceFusion", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 58b64511d4..ec0e637bce 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -39,7 +39,7 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" -#include "core/optimizer/gather_to_split_fusion.h" +#include "core/optimizer/gather_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -268,6 +268,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 70a23dd622..fa4a1c2dec 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -39,7 +39,7 @@ #include "core/optimizer/embed_layer_norm_fusion.h" #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" -#include "core/optimizer/gather_to_split_fusion.h" 
+#include "core/optimizer/gather_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -6391,5 +6391,89 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { } } +TEST_F(GraphTransformationTests, GatherToSliceFusion) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + ASSERT_EQ(op_count_map["Range"], 1); + ASSERT_EQ(op_count_map["Gather"], 1); + }; + + // OpSet-12, Tind is int32. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{8, 8, 8, 8}}); + auto* range_input_1 = builder.MakeInitializer({}, {0}); + auto* range_input_2 = builder.MakeInitializer({}, {8}); + auto* range_input_3 = builder.MakeInitializer({}, {1}); + auto* range_output = builder.MakeIntermediate(); + auto* gather_output = builder.MakeOutput(); + + builder.AddNode("Range", {range_input_1, range_input_2, range_input_3}, {range_output}); + builder.AddNode("Gather", {data_arg, range_output}, {gather_output}) + .AddAttribute("axis", static_cast(2)); + }; + + auto post_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + ASSERT_EQ(op_count_map["Range"], 0); + ASSERT_EQ(op_count_map["Gather"], 0); + ASSERT_EQ(op_count_map["Slice"], 1); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Slice") { + const NodeArg& input_arg = *(node.InputDefs()[3]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + ASSERT_TRUE(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + ASSERT_TRUE(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32); + ASSERT_EQ(2, *(init_const.data())); + } + } + }; + + std::unique_ptr transformer = std::make_unique(); + TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + 
pre_graph_checker, post_graph_checker); + } + + // OpSet-14, Tind is int64. + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{8, 8, 8, 8}}); + auto* range_input_1 = builder.MakeInitializer({}, {static_cast(0)}); + auto* range_input_2 = builder.MakeInitializer({}, {static_cast(8)}); + auto* range_input_3 = builder.MakeInitializer({}, {static_cast(1)}); + auto* range_output = builder.MakeIntermediate(); + auto* gather_output = builder.MakeOutput(); + + builder.AddNode("Range", {range_input_1, range_input_2, range_input_3}, {range_output}); + builder.AddNode("Gather", {data_arg, range_output}, {gather_output}) + .AddAttribute("axis", static_cast(2)); + }; + + auto post_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + ASSERT_EQ(op_count_map["Range"], 0); + ASSERT_EQ(op_count_map["Gather"], 0); + ASSERT_EQ(op_count_map["Slice"], 1); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Slice") { + const NodeArg& input_arg = *(node.InputDefs()[3]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = + graph_utils::GetConstantInitializer(graph, input_arg.Name()); + ASSERT_TRUE(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + ASSERT_TRUE(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + ASSERT_EQ(2, static_cast(*(init_const.data()))); + } + } + }; + + std::unique_ptr transformer = std::make_unique(); + TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker); + } +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 421c5c8663..9745a1b36d 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ 
b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -23,7 +23,7 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" -#include "core/optimizer/gather_to_split_fusion.h" +#include "core/optimizer/gather_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -103,6 +103,7 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps)); #if defined(USE_CUDA) || defined(USE_ROCM) // We are supposed to use execution provider as indicator, but here we don't have access to the registered EP at this point