From 16d7f5519318fc309eccdc957c14eba0bfe5bb10 Mon Sep 17 00:00:00 2001 From: zhijiang <43435212+zhijxu-MS@users.noreply.github.com> Date: Thu, 16 Nov 2023 17:08:06 +0800 Subject: [PATCH] lora conv1d replacement (#16643) in LoRA code, it will use conv1d to do projection for qkv, while the conv1d calculation is mathematically equivalent to matmul, and matmul is much faster than conv1d. The subsitution of the graph optimizer is: 1 conv1d >> 2 split + 1 squeeze + group_num matmul + 1 concat with this optimizer, we see 10%+ in one 1P model --- .../core/optimizer/conv1d_replacement.cc | 164 ++++++++++++++++++ .../core/optimizer/conv1d_replacement.h | 18 ++ .../core/optimizer/graph_transformer_utils.cc | 2 + .../test/optimizer/graph_transform_test.cc | 98 +++++++++++ 4 files changed, 282 insertions(+) create mode 100644 orttraining/orttraining/core/optimizer/conv1d_replacement.cc create mode 100644 orttraining/orttraining/core/optimizer/conv1d_replacement.h diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc new file mode 100644 index 0000000000..0412000e04 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc @@ -0,0 +1,164 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/optimizer/initializer.h" +#include "orttraining/core/optimizer/conv1d_replacement.h" +#include "core/graph/graph_utils.h" + +/* + In LoRA code, it will use conv1d to do projection for qkv, + while the conv1d calculation is mathematically equivalent to MatMul, and MatMul is much faster than conv1d in GPU. + The graph transformation is doing the following graph substitution: + 1. The input graph is: + conv_input conv_weight + \ / + \ / + conv1d + + 2. The output graph is as follows, + the number of MatMul is equal to attribute "group" of conv1d + conv_input conv1d.group conv_weight conv1d.group + \ / \ / + \ / Squeeze / + \ / \ / + Split Split + / / ... \ / / ... \ + / / ... \ / / ... \ + / / ... \ / / ... \ + input0 input1 ... inputN weight0 weight1 ... weightN + \ \ \ / / / + \ \ \ / / / + \ \ \ / / / + \ \ X / / + \ \ / \ / / + \ \ / X / + \ X / \ / + \ / \ / \ / + MatMul MatMul ... MatMul + \ | ... / + \ | / + \ | / +*/ +namespace onnxruntime { +bool NodeCanBeReplacedByMatmul(const Node& node) { + // If node type is Conv, and attr "dilations" is 1, "kernel_shape" is 1, "stride" is 1, group is 1 or 2, + // then it can be replaced by MatMul + // Kernel_shape is 1 means it is conv1d + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11})) { + return false; + } + const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); + const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); + const auto* stride = graph_utils::GetNodeAttribute(node, "strides"); + const auto* group = graph_utils::GetNodeAttribute(node, "group"); + if (dilations == nullptr || kernel_shape == nullptr || stride == nullptr || group == nullptr) { + return false; + } + if ((dilations->ints_size() && dilations->ints(0) != 1) || + (kernel_shape->ints_size() && kernel_shape->ints(0) != 1) || + (stride->ints_size() && stride->ints(0) != 1) || + group->i() >= 3) { + return false; + } + + return true; +} + +void Conv1dToMatmul(Graph& graph, Node& conv) { + // Shape of conv1d input: [batch_size, in_channels, in_length] + // Shape of conv1d weight:[output_channels, input_channels/group, kernel_shape], kernel_shape is 1 + // We need to split the input into "group", and squeeze&split the weight, and then do MatMul + const std::string node_description("Conv1dReplacement"); + auto execution_provider_type = conv.GetExecutionProviderType(); + // 1. Split conv input + auto group_attr = graph_utils::GetNodeAttribute(conv, "group"); + int64_t group_num = 1; // default group is 1 from ONNX schema + if (group_attr != nullptr) { + group_num = group_attr->i(); + } + auto conv1d_input = conv.MutableInputDefs()[0]; + std::vector conv1d_input_splitted_outputs; + for (int i = 0; i < group_num; i++) { + conv1d_input_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("input_split_output"), nullptr)); + } + auto& input_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, {conv1d_input}, + {conv1d_input_splitted_outputs}); + input_split.SetExecutionProviderType(execution_provider_type); + input_split.AddAttribute("axis", int64_t(1)); + auto onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + if (onnx_opset_version >= 18) { + input_split.AddAttribute("num_outputs", group_num); + } + // 2. Squeeze conv weight + auto conv1d_weight = conv.MutableInputDefs()[1]; + auto weight_squeeze_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("weight_squeeze_output"), nullptr); + auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName("WeightSqueeze"), "Squeeze", + node_description, {conv1d_weight}, {weight_squeeze_output}); + if (onnx_opset_version > 12) { + // After onnx version 12, squeeze node has axes as input instead of attribute + ONNX_NAMESPACE::TensorProto initializer_proto; + initializer_proto.set_name(graph.GenerateNodeName("ConstAsInitializer")); + initializer_proto.add_dims(static_cast(1)); + initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + InlinedVector initializer_proto_value{2}; + initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t)); + auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto); + // Squeeze node doesn't have opschema here, so we need to set input args count manually + weight_squeeze.MutableInputArgsCount().resize(2); + graph_utils::AddNodeInput(weight_squeeze, 1, axes_input); + } else { + weight_squeeze.AddAttribute("axes", std::vector{2}); + } + weight_squeeze.SetExecutionProviderType(execution_provider_type); + // 3. Split conv weight + std::vector conv1d_weight_splitted_outputs; + for (int i = 0; i < group_num; i++) { + conv1d_weight_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("weight_split_output"), nullptr)); + } + auto& weight_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, + {weight_squeeze_output}, {conv1d_weight_splitted_outputs}); + weight_split.AddAttribute("axis", int64_t(0)); + weight_split.SetExecutionProviderType(execution_provider_type); + if (onnx_opset_version >= 18) { + weight_split.AddAttribute("num_outputs", group_num); + } + // 4. Do MatMul + std::vector matmul_outputs; + for (int i = 0; i < group_num; i++) { + auto matmul_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("matmul_output"), nullptr); + matmul_outputs.push_back(matmul_output); + auto& matmul = graph.AddNode(graph.GenerateNodeName("Matmul"), "MatMul", node_description, + {conv1d_weight_splitted_outputs[i], conv1d_input_splitted_outputs[i]}, + {matmul_output}); + matmul.SetExecutionProviderType(execution_provider_type); + } + // 5. Concat matmul outputs + auto& concat_node = graph.AddNode(graph.GenerateNodeName("Concat"), "Concat", node_description, + matmul_outputs, {}); + concat_node.SetExecutionProviderType(execution_provider_type); + concat_node.AddAttribute("axis", int64_t(1)); + // 6. Clean up - delted original "conv" node, its output is replaced by concat_node + graph_utils::FinalizeNodeFusion(graph, concat_node, conv); +} + +Status Conv1dReplacement::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + for (auto node_index : node_topology_list) { + auto* node_ptr = graph.GetNode(node_index); + if (!node_ptr) + continue; // node was removed + auto& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + if (NodeCanBeReplacedByMatmul(node)) { + LOGS(logger, VERBOSE) << "lora conv1d replacement, node name: " + node.Name(); + Conv1dToMatmul(graph, node); + modified = true; + } + } + return Status::OK(); +} +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.h b/orttraining/orttraining/core/optimizer/conv1d_replacement.h new file mode 100644 index 0000000000..740f13c76f --- /dev/null +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +class Conv1dReplacement : public GraphTransformer { + public: + Conv1dReplacement(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("Conv1dReplacement", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 57d76577f1..6193a1d10c 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -72,6 +72,7 @@ #ifdef ENABLE_TRAINING_TORCH_INTEROP #include "orttraining/core/optimizer/pythonop_rewriter.h" #endif +#include "orttraining/core/optimizer/conv1d_replacement.h" namespace onnxruntime { namespace training { @@ -194,6 +195,7 @@ std::vector> GeneratePreTrainingTransformers( // Once we have a CPU kernel for PadAndUnflatten, we can remove the guard. transformers.emplace_back(std::make_unique(compatible_eps, config.sparse_embedding_input_names)); + transformers.emplace_back(std::make_unique(compatible_eps)); #endif } diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 20b9354d85..b774fec11c 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -35,6 +35,7 @@ #ifdef ENABLE_TRITON #include "orttraining/core/optimizer/triton_fusion.h" #endif +#include "orttraining/core/optimizer/conv1d_replacement.h" #include @@ -1199,6 +1200,103 @@ TEST_P(QDQFusionTestsParameterized, CheckModelComposition) { ASSERT_EQ(op_to_count_post_fusion["com.microsoft.FakeQuant"], 1); } +TEST_F(GraphTransformationTests, Conv1dReplacement) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + for (auto group : {1, 2}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / group, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(group)); + }; + + auto post_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 0); + // after graph transformation, the graph should have 1 squeeze, 2 split, group matmul, 1 concat + TEST_RETURN_IF_NOT(op_count_map["Squeeze"] == 1); + TEST_RETURN_IF_NOT(op_count_map["Split"] == 2); + TEST_RETURN_IF_NOT(op_count_map["MatMul"] == group); + TEST_RETURN_IF_NOT(op_count_map["Concat"] == 1); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, post_graph_checker)); + } + } +} + +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + // "group" is 3 so conv not replaced + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / 3, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(3)); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } + + // "kernel_shape" is not 1 so conv not replaced + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{2}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(1)); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } +} + INSTANTIATE_TEST_SUITE_P( QDQFusionTests, QDQFusionTestsParameterized,