Gather to Slice Fusion (#13599)

This PR optimizes the graph pattern produced by the following code from Huggingface's
XLNet model.
```
x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
```

The code will be exported to Range->Gather, which can be fused to a
Slice op. The Slice kernel is much faster than Gather, especially in the
backward pass. The main reason is that Gather's indices may contain
duplicates, so its gradient must sum contributions for repeated indices,
whereas a Slice can never have duplicated indices and needs no such summation.

Use Huggingface's XLNet model for profiling.
- Before the fusion
forward, ~753us

![image](https://user-images.githubusercontent.com/11661208/200758439-63f2f9b5-9610-4df8-98c8-a1ad4dc62f4e.png)
backward, ~46101us

![image](https://user-images.githubusercontent.com/11661208/200758530-fe16a8ec-ea8f-4b79-b3ac-386b72ba1670.png)

- After the fusion
forward, ~627us

![image](https://user-images.githubusercontent.com/11661208/200758654-ab9a6068-c45d-40f4-9c71-3862a56732f8.png)
backward, ~677us

![image](https://user-images.githubusercontent.com/11661208/200758833-aab1b8e1-1b5d-4e55-88cf-03c2a1d9d42b.png)
This commit is contained in:
Vincent Wang 2022-11-10 13:03:30 +08:00 committed by GitHub
parent 0511443782
commit 2bda3fd341
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 218 additions and 4 deletions

View file

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/gather_to_split_fusion.h"
#include "core/optimizer/gather_fusion.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
@ -175,4 +175,119 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le
return Status::OK();
}
/*
Fuse Range->Gather to Slice. Slice kernel is faster than Gather kernel in this case,
and SliceGrad is much faster than GatherGrad: Gather's indices may contain duplicates,
forcing a gradient summation in the backward pass, while Slice indices never do.
*/
Status GatherToSliceFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
                                      const logging::Logger& logger) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();

  for (auto node_index : node_topology_list) {
    auto* p_node = graph.GetNode(node_index);
    if (p_node == nullptr) continue;  // we removed the node as part of an earlier fusion

    Node& node = *p_node;
    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));

    // Pattern root: a Range node on a compatible EP whose output feeds exactly one edge.
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Range", {1, 11}) ||
        !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || node.GetOutputEdgesCount() != 1) {
      continue;
    }

    // The single consumer must be a Gather node on a compatible EP.
    Node& gather_node = *graph.GetNode(node.OutputNodesBegin()->Index());
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(gather_node, "Gather", {1, 11, 13}) ||
        !graph_utils::IsSupportedProvider(gather_node, GetCompatibleExecutionProviders())) {
      continue;
    }

    // Range's output is Gather's input[1] (the indices input, not the data input).
    if (node.MutableOutputDefs()[0] != gather_node.MutableInputDefs()[1]) {
      continue;
    }

    InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse{node, gather_node};

    auto& range_input_defs = node.MutableInputDefs();
    ORT_ENFORCE(range_input_defs.size() == 3);

    // Range's inputs are scalar, need unsqueeze to 1-D tensors:
    // Range(start, limit, delta) becomes Slice(data, starts=start, ends=limit, axes, steps=delta),
    // and Slice requires 1-D tensors for starts/ends/axes/steps.
    ONNX_NAMESPACE::TypeProto unsqueeze_output_type;
    const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
        node.MutableOutputDefs()[0]->TypeAsProto()->tensor_type().elem_type());
    unsqueeze_output_type.mutable_tensor_type()->set_elem_type(element_type);
    unsqueeze_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL);

    // One 1-D NodeArg per Range input (start, limit, delta).
    InlinedVector<NodeArg*> unsqueeze_outputs;
    for (size_t i = 0; i < range_input_defs.size(); ++i) {
      unsqueeze_outputs.emplace_back(&graph.GetOrCreateNodeArg(
          graph.GenerateNodeArgName("unsqueeze_output_" + std::to_string(i)), &unsqueeze_output_type));
    }

    // Unsqueeze before and after OpSet-13 have different schemas:
    // pre-13 takes "axes" as an attribute, 13+ takes it as a second input tensor.
    int onnx_opset_version = -1;
    if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) {
      onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain);
    }

    if (onnx_opset_version < 13) {
      for (size_t i = 0; i < range_input_defs.size(); ++i) {
        Node& unsqueeze_node =
            graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze",
                          "Unsqueeze for Fused Gather nodes", {range_input_defs[i]}, {unsqueeze_outputs[i]});
        unsqueeze_node.AddAttribute("axes", std::vector<int64_t>{static_cast<int64_t>(0)});
        unsqueeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
      }
    } else {
      // OpSet >= 13: axes is an input, so materialize a shared 1-D int64 initializer [0].
      ONNX_NAMESPACE::TensorProto unsqueeze_axes_initializer_proto;
      unsqueeze_axes_initializer_proto.set_name(graph.GenerateNodeName("UnsqueezeAxesInitializer"));
      unsqueeze_axes_initializer_proto.add_dims(static_cast<int64_t>(1));
      unsqueeze_axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
      unsqueeze_axes_initializer_proto.add_int64_data(static_cast<int64_t>(0));
      NodeArg* unsqueeze_axes_arg = &graph_utils::AddInitializer(graph, unsqueeze_axes_initializer_proto);

      for (size_t i = 0; i < range_input_defs.size(); ++i) {
        Node& unsqueeze_node = graph.AddNode(graph.GenerateNodeName("Unsqueeze_" + std::to_string(i)), "Unsqueeze",
                                             "Unsqueeze for Fused Gather nodes",
                                             {range_input_defs[i], unsqueeze_axes_arg}, {unsqueeze_outputs[i]});
        unsqueeze_node.SetExecutionProviderType(node.GetExecutionProviderType());
      }
    }

    // Gather's "axis" attribute becomes Slice's "axes" input.
    int64_t axis = 0;  // Default value.
    auto& attrs = gather_node.GetAttributes();
    if (attrs.find("axis") != attrs.end()) {
      auto& axis_attr = attrs.at("axis");
      if (utils::HasInt(axis_attr)) axis = axis_attr.i();
    }

    ONNX_NAMESPACE::TensorProto slice_axes_initializer_proto;
    slice_axes_initializer_proto.set_name(graph.GenerateNodeName("SliceAxesInitializer"));
    slice_axes_initializer_proto.add_dims(static_cast<int64_t>(1));
    // Slice requires starts/ends/axes/steps to share one index type, so the axes
    // initializer uses Range's output element type to match the unsqueezed inputs.
    slice_axes_initializer_proto.set_data_type(element_type);
    // Tind of Slice can only support int32 and int64.
    // (Gather's indices are constrained to int32/int64, so element_type is one of those here.)
    if (element_type == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
      slice_axes_initializer_proto.add_int64_data(axis);
    } else {
      slice_axes_initializer_proto.add_int32_data(static_cast<int32_t>(axis));
    }

    NodeArg* slice_axes_arg = &graph_utils::AddInitializer(graph, slice_axes_initializer_proto);
    // Slice(data, starts=Range.start, ends=Range.limit, axes=[Gather.axis], steps=Range.delta),
    // reusing Gather's output NodeArg so downstream consumers are untouched.
    Node& slice_node = graph.AddNode(graph.GenerateNodeName("Slice"), "Slice", "Slice for Fused Gather nodes",
                                     {gather_node.MutableInputDefs()[0], unsqueeze_outputs[0], unsqueeze_outputs[1],
                                      slice_axes_arg, unsqueeze_outputs[2]},
                                     {gather_node.MutableOutputDefs()[0]});
    slice_node.SetExecutionProviderType(gather_node.GetExecutionProviderType());

    // Remove the fused Range and Gather nodes now that Slice replaces them.
    for (Node& n : nodes_to_fuse) {
      graph_utils::RemoveNodeOutputEdges(graph, n);
      graph.RemoveNode(n.Index());
    }

    modified = true;
  }

  return Status::OK();
}

}  // namespace onnxruntime
} // namespace onnxruntime

View file

@ -23,4 +23,17 @@ class GatherToSplitFusion : public GraphTransformer {
bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis) const;
};
/**
@Class GatherToGliceFusion
Fuse Range->Gather to Slice node.
*/
class GatherToSliceFusion : public GraphTransformer {
public:
GatherToSliceFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GatherToSliceFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -39,7 +39,7 @@
#include "core/optimizer/expand_elimination.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/free_dim_override_transformer.h"
#include "core/optimizer/gather_to_split_fusion.h"
#include "core/optimizer/gather_fusion.h"
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
@ -268,6 +268,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<GatherToSplitFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<MatmulTransposeFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_rocm_eps));

View file

@ -39,7 +39,7 @@
#include "core/optimizer/embed_layer_norm_fusion.h"
#include "core/optimizer/expand_elimination.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/gather_to_split_fusion.h"
#include "core/optimizer/gather_fusion.h"
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
@ -6391,5 +6391,89 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) {
}
}
TEST_F(GraphTransformationTests, GatherToSliceFusion) {
auto pre_graph_checker = [&](Graph& graph) {
auto op_count_map = CountOpsInGraph(graph);
ASSERT_EQ(op_count_map["Range"], 1);
ASSERT_EQ(op_count_map["Gather"], 1);
};
// OpSet-12, Tind is int32.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{8, 8, 8, 8}});
auto* range_input_1 = builder.MakeInitializer<int32_t>({}, {0});
auto* range_input_2 = builder.MakeInitializer<int32_t>({}, {8});
auto* range_input_3 = builder.MakeInitializer<int32_t>({}, {1});
auto* range_output = builder.MakeIntermediate();
auto* gather_output = builder.MakeOutput();
builder.AddNode("Range", {range_input_1, range_input_2, range_input_3}, {range_output});
builder.AddNode("Gather", {data_arg, range_output}, {gather_output})
.AddAttribute("axis", static_cast<int64_t>(2));
};
auto post_graph_checker = [&](Graph& graph) {
auto op_count_map = CountOpsInGraph(graph);
ASSERT_EQ(op_count_map["Range"], 0);
ASSERT_EQ(op_count_map["Gather"], 0);
ASSERT_EQ(op_count_map["Slice"], 1);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Slice") {
const NodeArg& input_arg = *(node.InputDefs()[3]);
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, input_arg.Name());
ASSERT_TRUE(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
ASSERT_TRUE(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT32);
ASSERT_EQ(2, *(init_const.data<int32_t>()));
}
}
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSliceFusion>();
TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// OpSet-14, Tind is int64.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* data_arg = builder.MakeInput<float>({{8, 8, 8, 8}});
auto* range_input_1 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(0)});
auto* range_input_2 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(8)});
auto* range_input_3 = builder.MakeInitializer<int64_t>({}, {static_cast<int64_t>(1)});
auto* range_output = builder.MakeIntermediate();
auto* gather_output = builder.MakeOutput();
builder.AddNode("Range", {range_input_1, range_input_2, range_input_3}, {range_output});
builder.AddNode("Gather", {data_arg, range_output}, {gather_output})
.AddAttribute("axis", static_cast<int64_t>(2));
};
auto post_graph_checker = [&](Graph& graph) {
auto op_count_map = CountOpsInGraph(graph);
ASSERT_EQ(op_count_map["Range"], 0);
ASSERT_EQ(op_count_map["Gather"], 0);
ASSERT_EQ(op_count_map["Slice"], 1);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "Slice") {
const NodeArg& input_arg = *(node.InputDefs()[3]);
const ONNX_NAMESPACE::TensorProto* tensor_proto =
graph_utils::GetConstantInitializer(graph, input_arg.Name());
ASSERT_TRUE(tensor_proto != nullptr);
Initializer init_const{*tensor_proto, graph.ModelPath()};
ASSERT_TRUE(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64);
ASSERT_EQ(2, static_cast<int32_t>(*(init_const.data<int64_t>())));
}
}
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<GatherToSliceFusion>();
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
}
} // namespace test
} // namespace onnxruntime

View file

@ -23,7 +23,7 @@
#include "core/optimizer/expand_elimination.h"
#include "core/optimizer/fast_gelu_fusion.h"
#include "core/optimizer/free_dim_override_transformer.h"
#include "core/optimizer/gather_to_split_fusion.h"
#include "core/optimizer/gather_fusion.h"
#include "core/optimizer/gelu_approximation.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/optimizer/gemm_activation_fusion.h"
@ -103,6 +103,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
transformers.emplace_back(std::make_unique<QuickGeluFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<SoftmaxCrossEntropyLossInternalFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<GatherToSplitFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(compatible_eps));
#if defined(USE_CUDA) || defined(USE_ROCM)
// We are supposed to use execution provider as indicator, but here we don't have access to the registered EP at this point