mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Gather to Slice Fusion (#13599)
This PR is to optimize the running for below code from Huggingface's XLNet model. ``` x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long)) ``` The code will be exported to Range->Gather, which can be fused to a Slice Op. Slice kernel is much faster than Gather, especially for backward run. The main reason is for Gather, the data in indices can be duplicated so that it needs sum during backward, but Slice node cannot have such case. Use Huggingface's XLNet model for profiling. - Before the fuse forward, ~753us  backward, ~46101us  - After the fuse forward, ~627us  backward, ~677us 
This commit is contained in:
parent
0511443782
commit
2bda3fd341
5 changed files with 218 additions and 4 deletions
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/gather_to_split_fusion.h"
|
||||
#include "core/optimizer/gather_fusion.h"
|
||||
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
|
|
@ -175,4 +175,119 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/*
|
||||
Fuse Range->Gather to Slice. Slice kernel is faster than Gather kernel in this case,
|
||||
and SliceGrad is much faster than GatherGrad.
|
||||
*/
|
||||
Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
||||
const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto* p_node = graph.GetNode(node_index);
|
||||
if (p_node == nullptr) continue; // we removed the node as part of an earlier fusion
|
||||
Node& node = *p_node;
|
||||
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Range", {1, 11}) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
Node& gather_node = *graph.GetNode(node.OutputNodesBegin()->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(gather_node, "Gather", {1, 11, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(gather_node, GetCompatibleExecutionProviders())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Range's output is Gather's input[1].
|
||||
if (node.MutableOutputDefs()[0] != gather_node.MutableInputDefs()[1]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse{node, gather_node};
|
||||
|
||||
auto& range_input_defs = node.MutableInputDefs();
|
||||
ORT_ENFORCE(range_input_defs.size() == 3);
|
||||
// Range's inputs are scalar, need unsqueeze to 1-D tensors.
|
||||
ONNX_NAMESPACE::TypeProto unsqueeze_output_type;
|
||||
const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
|
||||
node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type());
|
||||
unsqueeze_output_type.mutable_tensor_type()->set_elem_type(element_type);
|
||||
unsqueeze_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL);
|
||||
|
||||
InlinedVector<NodeArg*> unsqueeze_outputs;
|
||||
for (size_t i = 0; i < range_input_defs.size(); ++i) {
|
||||
unsqueeze_outputs.emplace_back(&graph.GetOrCreateNodeArg(
|
||||
graph.GenerateNodeArgName("unsqueeze_output_" + std::to_string(i)), &unsqueeze_output_type));
|
||||
}
|
||||
|
||||
// Unsqueeze before and after OpSet-13 have different schemas.
|
||||
int onnx_opset_version = -1;
|
||||
if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
|
||||
onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
|
||||
}
|
||||
|
||||
if (onnx_opset_version < 13) {
|
||||
for (size_t i = 0; i < range_input_defs.size(); ++i) {
|
||||
Node& unsqueeze_node =
|
||||
graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze",
|
||||
"Unsqueeze for Fused Gather nodes", {range_input_defs[i]}, {unsqueeze_outputs[i]});
|
||||
unsqueeze_node.AddAttribute("axes", std::vector<int64_t>{static_cast<int64_t>(0)});
|
||||
unsqueeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
}
|
||||
} else {
|
||||
ONNX_NAMESPACE::TensorProto unsqueeze_axes_initializer_proto;
|
||||
unsqueeze_axes_initializer_proto.set_name(graph.GenerateNodeName("UnsqueezeAxesInitializer"));
|
||||
unsqueeze_axes_initializer_proto.add_dims(static_cast<int64_t>(1));
|
||||
unsqueeze_axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
unsqueeze_axes_initializer_proto.add_int64_data(static_cast<int64_t>(0));
|
||||
NodeArg* unsqueeze_axes_arg = &graph_utils::AddInitializer(graph, unsqueeze_axes_initializer_proto);
|
||||
|
||||
for (size_t i = 0; i < range_input_defs.size(); ++i) {
|
||||
Node& unsqueeze_node = graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze",
|
||||
"Unsqueeze for Fused Gather nodes",
|
||||
{range_input_defs[i], unsqueeze_axes_arg}, {unsqueeze_outputs[i]});
|
||||
unsqueeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
}
|
||||
}
|
||||
|
||||
int64_t axis = 0; // Default value.
|
||||
auto& attrs = gather_node.GetAttributes();
|
||||
if (attrs.find("axis") != attrs.end()) {
|
||||
auto& axis_attr = attrs.at("axis");
|
||||
if (utils::HasInt(axis_attr)) axis = axis_attr.i();
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorProto slice_axes_initializer_proto;
|
||||
slice_axes_initializer_proto.set_name(graph.GenerateNodeName("SliceAxesInitializer"));
|
||||
slice_axes_initializer_proto.add_dims(static_cast<int64_t>(1));
|
||||
slice_axes_initializer_proto.set_data_type(element_type);
|
||||
// Tind of Slice can only support int32 and int64.
|
||||
if (element_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
|
||||
slice_axes_initializer_proto.add_int64_data(axis);
|
||||
} else {
|
||||
slice_axes_initializer_proto.add_int32_data(static_cast<int32_t>(axis));
|
||||
}
|
||||
NodeArg* slice_axes_arg = &graph_utils::AddInitializer(graph, slice_axes_initializer_proto);
|
||||
Node& slice_node = graph.AddNode(graph.GenerateNodeName("Slice"), "Slice", "Slice for Fused Gather nodes",
|
||||
{gather_node.MutableInputDefs()[0], unsqueeze_outputs[0], unsqueeze_outputs[1],
|
||||
slice_axes_arg, unsqueeze_outputs[2]},
|
||||
{gather_node.MutableOutputDefs()[0]});
|
||||
slice_node.SetExecutionProviderType(gather_node.GetExecutionProviderType());
|
||||
|
||||
for (Node& n : nodes_to_fuse) {
|
||||
graph_utils::RemoveNodeOutputEdges(graph, n);
|
||||
graph.RemoveNode(n.Index());
|
||||
}
|
||||
|
||||
modified = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -23,4 +23,17 @@ class GatherToSplitFusion : public GraphTransformer {
|
|||
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const;
|
||||
};
|
||||
|
||||
/**
|
||||
@Class GatherToGliceFusion
|
||||
|
||||
Fuse Range->Gather to Slice node.
|
||||
*/
|
||||
class GatherToSliceFusion : public GraphTransformer {
|
||||
public:
|
||||
GatherToSliceFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("GatherToSliceFusion", compatible_execution_providers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -39,7 +39,7 @@
|
|||
#include "core/optimizer/expand_elimination.h"
|
||||
#include "core/optimizer/fast_gelu_fusion.h"
|
||||
#include "core/optimizer/free_dim_override_transformer.h"
|
||||
#include "core/optimizer/gather_to_split_fusion.h"
|
||||
#include "core/optimizer/gather_fusion.h"
|
||||
#include "core/optimizer/gelu_approximation.h"
|
||||
#include "core/optimizer/gelu_fusion.h"
|
||||
#include "core/optimizer/gemm_activation_fusion.h"
|
||||
|
|
@ -268,6 +268,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<GatherToSplitFusion>(cpu_cuda_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));
|
||||
|
||||
transformers.emplace_back(std::make_unique<MatmulTransposeFusion>(cpu_cuda_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_rocm_eps));
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@
|
|||
#include "core/optimizer/embed_layer_norm_fusion.h"
|
||||
#include "core/optimizer/expand_elimination.h"
|
||||
#include "core/optimizer/fast_gelu_fusion.h"
|
||||
#include "core/optimizer/gather_to_split_fusion.h"
|
||||
#include "core/optimizer/gather_fusion.h"
|
||||
#include "core/optimizer/gelu_approximation.h"
|
||||
#include "core/optimizer/gelu_fusion.h"
|
||||
#include "core/optimizer/gemm_activation_fusion.h"
|
||||
|
|
@ -6391,5 +6391,89 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, GatherToSliceFusion) {
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
auto op_count_map = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_count_map["Range"], 1);
|
||||
ASSERT_EQ(op_count_map["Gather"], 1);
|
||||
};
|
||||
|
||||
// OpSet-12, Tind is int32.
|
||||
{
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* data_arg = builder.MakeInput<float>({{8, 8, 8, 8}});
|
||||
auto* range_input_1 = builder.MakeInitializer<int32_t>({}, {0});
|
||||
auto* range_input_2 = builder.MakeInitializer<int32_t>({}, {8});
|
||||
auto* range_input_3 = builder.MakeInitializer<int32_t>({}, {1});
|
||||
auto* range_output = builder.MakeIntermediate();
|
||||
auto* gather_output = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Range", {range_input_1, range_input_2, range_input_3}, {range_output});
|
||||
builder.AddNode("Gather", {data_arg, range_output}, {gather_output})
|
||||
.AddAttribute("axis", static_cast<int64_t>(2));
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
auto op_count_map = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_count_map["Range"], 0);
|
||||
ASSERT_EQ(op_count_map["Gather"], 0);
|
||||
ASSERT_EQ(op_count_map["Slice"], 1);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Slice") {
|
||||
const NodeArg& input_arg = *(node.InputDefs()[3]);
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, input_arg.Name());
|
||||
ASSERT_TRUE(tensor_proto != nullptr);
|
||||
Initializer init_const{*tensor_proto, graph.ModelPath()};
|
||||
ASSERT_TRUE(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32);
|
||||
ASSERT_EQ(2, *(init_const.data<int32_t>()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSliceFusion>();
|
||||
TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
|
||||
// OpSet-14, Tind is int64.
|
||||
{
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* data_arg = builder.MakeInput<float>({{8, 8, 8, 8}});
|
||||
auto* range_input_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
|
||||
auto* range_input_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(8)});
|
||||
auto* range_input_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
|
||||
auto* range_output = builder.MakeIntermediate();
|
||||
auto* gather_output = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Range", {range_input_1, range_input_2, range_input_3}, {range_output});
|
||||
builder.AddNode("Gather", {data_arg, range_output}, {gather_output})
|
||||
.AddAttribute("axis", static_cast<int64_t>(2));
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
auto op_count_map = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_count_map["Range"], 0);
|
||||
ASSERT_EQ(op_count_map["Gather"], 0);
|
||||
ASSERT_EQ(op_count_map["Slice"], 1);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Slice") {
|
||||
const NodeArg& input_arg = *(node.InputDefs()[3]);
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto =
|
||||
graph_utils::GetConstantInitializer(graph, input_arg.Name());
|
||||
ASSERT_TRUE(tensor_proto != nullptr);
|
||||
Initializer init_const{*tensor_proto, graph.ModelPath()};
|
||||
ASSERT_TRUE(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
ASSERT_EQ(2, static_cast<int32_t>(*(init_const.data<int64_t>())));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSliceFusion>();
|
||||
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
#include "core/optimizer/expand_elimination.h"
|
||||
#include "core/optimizer/fast_gelu_fusion.h"
|
||||
#include "core/optimizer/free_dim_override_transformer.h"
|
||||
#include "core/optimizer/gather_to_split_fusion.h"
|
||||
#include "core/optimizer/gather_fusion.h"
|
||||
#include "core/optimizer/gelu_approximation.h"
|
||||
#include "core/optimizer/gelu_fusion.h"
|
||||
#include "core/optimizer/gemm_activation_fusion.h"
|
||||
|
|
@ -103,6 +103,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
transformers.emplace_back(std::make_unique<QuickGeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<SoftmaxCrossEntropyLossInternalFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<GatherToSplitFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(compatible_eps));
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
// We are supposed to use execution provider as indicator, but here we don't have access to the registered EP at this point
|
||||
|
|
|
|||
Loading…
Reference in a new issue