From 2a5c9b86ebbdba8fb76f79de26524a2fdd2e5c2a Mon Sep 17 00:00:00 2001 From: zhijiang <43435212+zhijxu-MS@users.noreply.github.com> Date: Tue, 5 Mar 2024 10:11:19 +0800 Subject: [PATCH] Zhijxu/fix conv1d replacement (#19758) remove the constraint - "group number should be less than 3"; add more condition to make sure the conv1d replacement only happens on conv1d instead of conv2d/conv3d; add more tests; --- .../core/optimizer/conv1d_replacement.cc | 67 ++++++++++++------- .../test/optimizer/graph_transform_test.cc | 66 +++++++++++++++--- 2 files changed, 99 insertions(+), 34 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc index 0412000e04..ff220fcb06 100644 --- a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc @@ -42,30 +42,45 @@ */ namespace onnxruntime { bool NodeCanBeReplacedByMatmul(const Node& node) { - // If node type is Conv, and attr "dilations" is 1, "kernel_shape" is 1, "stride" is 1, group is 1 or 2, - // then it can be replaced by MatMul - // Kernel_shape is 1 means it is conv1d + /* + If node type is Conv, and satisfy the following conditions then it can be replaced by MatMul: + - not bias as input which means only has 2 inputs: input and weight + - "dilations" should be [1] + size 1 means conv1d + - "strides" should be [1] + - "pads" should be [0,0] + - "autopad" should be "NOTSET" + - "kernel_shape" should be [1] + */ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11})) { return false; } - const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); - const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); - const auto* stride = graph_utils::GetNodeAttribute(node, "strides"); - const auto* group = graph_utils::GetNodeAttribute(node, "group"); - if (dilations == nullptr || kernel_shape == nullptr || stride == nullptr || group == nullptr) { - return false; - } - if ((dilations->ints_size() && dilations->ints(0) != 1) || - (kernel_shape->ints_size() && kernel_shape->ints(0) != 1) || - (stride->ints_size() && stride->ints(0) != 1) || - group->i() >= 3) { + + // TODO: bias input can also be supported if needed + if (node.InputDefs().size() != 2) { return false; } - return true; + const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); + const auto* strides = graph_utils::GetNodeAttribute(node, "strides"); + const auto* pads = graph_utils::GetNodeAttribute(node, "pads"); + const auto* autopad = graph_utils::GetNodeAttribute(node, "auto_pad"); + const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); + if (dilations == nullptr || strides == nullptr || pads == nullptr || autopad == nullptr || kernel_shape == nullptr) { + return false; + } + + if ((dilations->ints_size() == 1 && dilations->ints(0) == 1) && + (strides->ints_size() == 1 && strides->ints(0) == 1) && + (autopad->s() == "NOTSET") && + (pads->ints_size() == 2 && pads->ints(0) == 0 && pads->ints(1) == 0) && + (kernel_shape->ints_size() == 1 && kernel_shape->ints(0) == 1)) { + return true; + } + return false; } -void Conv1dToMatmul(Graph& graph, Node& conv) { +void Conv1dToMatmul(Graph& graph, Node& conv, const std::string transformer_name) { // Shape of conv1d input: [batch_size, in_channels, in_length] // Shape of conv1d weight:[output_channels, input_channels/group, kernel_shape], kernel_shape is 1 // We need to split the input into "group", and squeeze&split the weight, and then do MatMul @@ -83,7 +98,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { conv1d_input_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( graph.GenerateNodeArgName("input_split_output"), nullptr)); } - auto& input_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, {conv1d_input}, + auto& input_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description, {conv1d_input}, {conv1d_input_splitted_outputs}); input_split.SetExecutionProviderType(execution_provider_type); input_split.AddAttribute("axis", int64_t(1)); @@ -93,23 +108,25 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { } // 2. Squeeze conv weight auto conv1d_weight = conv.MutableInputDefs()[1]; + // auto con1d_bias = xx; auto weight_squeeze_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("weight_squeeze_output"), nullptr); - auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName("WeightSqueeze"), "Squeeze", + auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName(transformer_name + "WeightSqueeze"), "Squeeze", node_description, {conv1d_weight}, {weight_squeeze_output}); + int64_t weight_squeeze_axis = 2; if (onnx_opset_version > 12) { // After onnx version 12, squeeze node has axes as input instead of attribute ONNX_NAMESPACE::TensorProto initializer_proto; - initializer_proto.set_name(graph.GenerateNodeName("ConstAsInitializer")); + initializer_proto.set_name(graph.GenerateNodeName(transformer_name + "ConstAsInitializer")); initializer_proto.add_dims(static_cast(1)); initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - InlinedVector initializer_proto_value{2}; + InlinedVector initializer_proto_value{weight_squeeze_axis}; initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t)); auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto); // Squeeze node doesn't have opschema here, so we need to set input args count manually weight_squeeze.MutableInputArgsCount().resize(2); graph_utils::AddNodeInput(weight_squeeze, 1, axes_input); } else { - weight_squeeze.AddAttribute("axes", std::vector{2}); + weight_squeeze.AddAttribute("axes", std::vector{weight_squeeze_axis}); } weight_squeeze.SetExecutionProviderType(execution_provider_type); // 3. Split conv weight @@ -118,7 +135,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { conv1d_weight_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( graph.GenerateNodeArgName("weight_split_output"), nullptr)); } - auto& weight_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, + auto& weight_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description, {weight_squeeze_output}, {conv1d_weight_splitted_outputs}); weight_split.AddAttribute("axis", int64_t(0)); weight_split.SetExecutionProviderType(execution_provider_type); @@ -130,13 +147,13 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { for (int i = 0; i < group_num; i++) { auto matmul_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("matmul_output"), nullptr); matmul_outputs.push_back(matmul_output); - auto& matmul = graph.AddNode(graph.GenerateNodeName("Matmul"), "MatMul", node_description, + auto& matmul = graph.AddNode(graph.GenerateNodeName(transformer_name + "Matmul"), "MatMul", node_description, {conv1d_weight_splitted_outputs[i], conv1d_input_splitted_outputs[i]}, {matmul_output}); matmul.SetExecutionProviderType(execution_provider_type); } // 5. Concat matmul outputs - auto& concat_node = graph.AddNode(graph.GenerateNodeName("Concat"), "Concat", node_description, + auto& concat_node = graph.AddNode(graph.GenerateNodeName(transformer_name + "Concat"), "Concat", node_description, matmul_outputs, {}); concat_node.SetExecutionProviderType(execution_provider_type); concat_node.AddAttribute("axis", int64_t(1)); @@ -155,7 +172,7 @@ Status Conv1dReplacement::ApplyImpl(Graph& graph, bool& modified, int graph_leve ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if (NodeCanBeReplacedByMatmul(node)) { LOGS(logger, VERBOSE) << "lora conv1d replacement, node name: " + node.Name(); - Conv1dToMatmul(graph, node); + Conv1dToMatmul(graph, node, Name()); modified = true; } } diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index bab7c09839..109937ff96 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -1200,7 +1200,7 @@ TEST_P(QDQFusionTestsParameterized, CheckModelComposition) { ASSERT_EQ(op_to_count_post_fusion["com.microsoft.FakeQuant"], 1); } -TEST_F(GraphTransformationTests, Conv1dReplacement) { +TEST_F(GraphTransformationTests, Conv1dReplacement_TakeEffect) { auto pre_graph_checker = [&](Graph& graph) { auto op_count_map = CountOpsInGraph(graph); TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); @@ -1208,7 +1208,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { }; for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { - for (auto group : {1, 2}) { + for (auto group : {1, 2, 4}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); auto out_channel = 64; @@ -1222,6 +1222,8 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); conv_node.AddAttribute("group", static_cast(group)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); }; auto post_graph_checker = [&](Graph& graph) { @@ -1243,28 +1245,31 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { } } -TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { +// node has bias input so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect1) { auto pre_graph_checker = [&](Graph& graph) { auto op_count_map = CountOpsInGraph(graph); TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); return Status::OK(); }; - // "group" is 3 so conv not replaced for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); auto out_channel = 64; auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); - auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / 3, 1}, {-1.0f, 1.0f}); + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); + auto* bias_arg = builder.MakeInitializer({out_channel}, {-1.0f, 1.0f}); auto* conv_output = builder.MakeOutput(); - auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg, bias_arg}, {conv_output}); conv_node.AddAttribute("dilations", std::vector{1}); conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); - conv_node.AddAttribute("group", static_cast(3)); + conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); }; std::unique_ptr transformer = std::make_unique(); @@ -1272,8 +1277,16 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { TransformerLevel::Level1, 1, pre_graph_checker, pre_graph_checker)); } +} + +// "auto_pad " is not NOTSET so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect2) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; - // "kernel_shape" is not 1 so conv not replaced for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); @@ -1285,9 +1298,44 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); conv_node.AddAttribute("dilations", std::vector{1}); - conv_node.AddAttribute("kernel_shape", std::vector{2}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "VALID"); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } +} + +// pads is not all zero, so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect3) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{1, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); }; std::unique_ptr transformer = std::make_unique();