mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Zhijxu/fix conv1d replacement (#19758)
Remove the constraint "group number should be less than 3"; add more conditions to make sure the conv1d replacement only happens on conv1d and not on conv2d/conv3d; add more tests.
This commit is contained in:
parent
0cdf36faeb
commit
2a5c9b86eb
2 changed files with 99 additions and 34 deletions
|
|
@ -42,30 +42,45 @@
|
|||
*/
|
||||
namespace onnxruntime {
|
||||
bool NodeCanBeReplacedByMatmul(const Node& node) {
|
||||
// If node type is Conv, and attr "dilations" is 1, "kernel_shape" is 1, "stride" is 1, group is 1 or 2,
|
||||
// then it can be replaced by MatMul
|
||||
// Kernel_shape is 1 means it is conv1d
|
||||
/*
|
||||
If node type is Conv, and satisfy the following conditions then it can be replaced by MatMul:
|
||||
- no bias input, i.e. the node has exactly 2 inputs: input and weight
|
||||
- "dilations" should be [1]
|
||||
size 1 means conv1d
|
||||
- "strides" should be [1]
|
||||
- "pads" should be [0,0]
|
||||
- "auto_pad" should be "NOTSET"
|
||||
- "kernel_shape" should be [1]
|
||||
*/
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11})) {
|
||||
return false;
|
||||
}
|
||||
const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations");
|
||||
const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape");
|
||||
const auto* stride = graph_utils::GetNodeAttribute(node, "strides");
|
||||
const auto* group = graph_utils::GetNodeAttribute(node, "group");
|
||||
if (dilations == nullptr || kernel_shape == nullptr || stride == nullptr || group == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if ((dilations->ints_size() && dilations->ints(0) != 1) ||
|
||||
(kernel_shape->ints_size() && kernel_shape->ints(0) != 1) ||
|
||||
(stride->ints_size() && stride->ints(0) != 1) ||
|
||||
group->i() >= 3) {
|
||||
|
||||
// TODO: bias input can also be supported if needed
|
||||
if (node.InputDefs().size() != 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations");
|
||||
const auto* strides = graph_utils::GetNodeAttribute(node, "strides");
|
||||
const auto* pads = graph_utils::GetNodeAttribute(node, "pads");
|
||||
const auto* autopad = graph_utils::GetNodeAttribute(node, "auto_pad");
|
||||
const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape");
|
||||
if (dilations == nullptr || strides == nullptr || pads == nullptr || autopad == nullptr || kernel_shape == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((dilations->ints_size() == 1 && dilations->ints(0) == 1) &&
|
||||
(strides->ints_size() == 1 && strides->ints(0) == 1) &&
|
||||
(autopad->s() == "NOTSET") &&
|
||||
(pads->ints_size() == 2 && pads->ints(0) == 0 && pads->ints(1) == 0) &&
|
||||
(kernel_shape->ints_size() == 1 && kernel_shape->ints(0) == 1)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void Conv1dToMatmul(Graph& graph, Node& conv) {
|
||||
void Conv1dToMatmul(Graph& graph, Node& conv, const std::string transformer_name) {
|
||||
// Shape of conv1d input: [batch_size, in_channels, in_length]
|
||||
// Shape of conv1d weight:[output_channels, input_channels/group, kernel_shape], kernel_shape is 1
|
||||
// We need to split the input into "group", and squeeze&split the weight, and then do MatMul
|
||||
|
|
@ -83,7 +98,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
|
|||
conv1d_input_splitted_outputs.push_back(&graph.GetOrCreateNodeArg(
|
||||
graph.GenerateNodeArgName("input_split_output"), nullptr));
|
||||
}
|
||||
auto& input_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, {conv1d_input},
|
||||
auto& input_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description, {conv1d_input},
|
||||
{conv1d_input_splitted_outputs});
|
||||
input_split.SetExecutionProviderType(execution_provider_type);
|
||||
input_split.AddAttribute("axis", int64_t(1));
|
||||
|
|
@ -93,23 +108,25 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
|
|||
}
|
||||
// 2. Squeeze conv weight
|
||||
auto conv1d_weight = conv.MutableInputDefs()[1];
|
||||
// auto con1d_bias = xx;
|
||||
auto weight_squeeze_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("weight_squeeze_output"), nullptr);
|
||||
auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName("WeightSqueeze"), "Squeeze",
|
||||
auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName(transformer_name + "WeightSqueeze"), "Squeeze",
|
||||
node_description, {conv1d_weight}, {weight_squeeze_output});
|
||||
int64_t weight_squeeze_axis = 2;
|
||||
if (onnx_opset_version > 12) {
|
||||
// After onnx version 12, squeeze node has axes as input instead of attribute
|
||||
ONNX_NAMESPACE::TensorProto initializer_proto;
|
||||
initializer_proto.set_name(graph.GenerateNodeName("ConstAsInitializer"));
|
||||
initializer_proto.set_name(graph.GenerateNodeName(transformer_name + "ConstAsInitializer"));
|
||||
initializer_proto.add_dims(static_cast<int64_t>(1));
|
||||
initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
|
||||
InlinedVector<int64_t> initializer_proto_value{2};
|
||||
InlinedVector<int64_t> initializer_proto_value{weight_squeeze_axis};
|
||||
initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t));
|
||||
auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto);
|
||||
// Squeeze node doesn't have opschema here, so we need to set input args count manually
|
||||
weight_squeeze.MutableInputArgsCount().resize(2);
|
||||
graph_utils::AddNodeInput(weight_squeeze, 1, axes_input);
|
||||
} else {
|
||||
weight_squeeze.AddAttribute("axes", std::vector<int64_t>{2});
|
||||
weight_squeeze.AddAttribute("axes", std::vector<int64_t>{weight_squeeze_axis});
|
||||
}
|
||||
weight_squeeze.SetExecutionProviderType(execution_provider_type);
|
||||
// 3. Split conv weight
|
||||
|
|
@ -118,7 +135,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
|
|||
conv1d_weight_splitted_outputs.push_back(&graph.GetOrCreateNodeArg(
|
||||
graph.GenerateNodeArgName("weight_split_output"), nullptr));
|
||||
}
|
||||
auto& weight_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description,
|
||||
auto& weight_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description,
|
||||
{weight_squeeze_output}, {conv1d_weight_splitted_outputs});
|
||||
weight_split.AddAttribute("axis", int64_t(0));
|
||||
weight_split.SetExecutionProviderType(execution_provider_type);
|
||||
|
|
@ -130,13 +147,13 @@ void Conv1dToMatmul(Graph& graph, Node& conv) {
|
|||
for (int i = 0; i < group_num; i++) {
|
||||
auto matmul_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("matmul_output"), nullptr);
|
||||
matmul_outputs.push_back(matmul_output);
|
||||
auto& matmul = graph.AddNode(graph.GenerateNodeName("Matmul"), "MatMul", node_description,
|
||||
auto& matmul = graph.AddNode(graph.GenerateNodeName(transformer_name + "Matmul"), "MatMul", node_description,
|
||||
{conv1d_weight_splitted_outputs[i], conv1d_input_splitted_outputs[i]},
|
||||
{matmul_output});
|
||||
matmul.SetExecutionProviderType(execution_provider_type);
|
||||
}
|
||||
// 5. Concat matmul outputs
|
||||
auto& concat_node = graph.AddNode(graph.GenerateNodeName("Concat"), "Concat", node_description,
|
||||
auto& concat_node = graph.AddNode(graph.GenerateNodeName(transformer_name + "Concat"), "Concat", node_description,
|
||||
matmul_outputs, {});
|
||||
concat_node.SetExecutionProviderType(execution_provider_type);
|
||||
concat_node.AddAttribute("axis", int64_t(1));
|
||||
|
|
@ -155,7 +172,7 @@ Status Conv1dReplacement::ApplyImpl(Graph& graph, bool& modified, int graph_leve
|
|||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
if (NodeCanBeReplacedByMatmul(node)) {
|
||||
LOGS(logger, VERBOSE) << "lora conv1d replacement, node name: " + node.Name();
|
||||
Conv1dToMatmul(graph, node);
|
||||
Conv1dToMatmul(graph, node, Name());
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1200,7 +1200,7 @@ TEST_P(QDQFusionTestsParameterized, CheckModelComposition) {
|
|||
ASSERT_EQ(op_to_count_post_fusion["com.microsoft.FakeQuant"], 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, Conv1dReplacement) {
|
||||
TEST_F(GraphTransformationTests, Conv1dReplacement_TakeEffect) {
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
auto op_count_map = CountOpsInGraph(graph);
|
||||
TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
|
||||
|
|
@ -1208,7 +1208,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) {
|
|||
};
|
||||
|
||||
for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
|
||||
for (auto group : {1, 2}) {
|
||||
for (auto group : {1, 2, 4}) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
|
||||
auto out_channel = 64;
|
||||
|
|
@ -1222,6 +1222,8 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) {
|
|||
conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
|
||||
conv_node.AddAttribute("strides", std::vector<int64_t>{1});
|
||||
conv_node.AddAttribute("group", static_cast<int64_t>(group));
|
||||
conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0});
|
||||
conv_node.AddAttribute("auto_pad", "NOTSET");
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
|
|
@ -1243,28 +1245,31 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) {
|
||||
// node has bias input so conv not replaced
|
||||
TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect1) {
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
auto op_count_map = CountOpsInGraph(graph);
|
||||
TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// "group" is 3 so conv not replaced
|
||||
for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
|
||||
auto out_channel = 64;
|
||||
auto* data_arg = builder.MakeInput<float>({{batch_size, in_channel, in_length}});
|
||||
|
||||
auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel / 3, 1}, {-1.0f, 1.0f});
|
||||
auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, {-1.0f, 1.0f});
|
||||
auto* bias_arg = builder.MakeInitializer<float>({out_channel}, {-1.0f, 1.0f});
|
||||
auto* conv_output = builder.MakeOutput();
|
||||
|
||||
auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output});
|
||||
auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg, bias_arg}, {conv_output});
|
||||
conv_node.AddAttribute("dilations", std::vector<int64_t>{1});
|
||||
conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
|
||||
conv_node.AddAttribute("strides", std::vector<int64_t>{1});
|
||||
conv_node.AddAttribute("group", static_cast<int64_t>(3));
|
||||
conv_node.AddAttribute("group", static_cast<int64_t>(1));
|
||||
conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0});
|
||||
conv_node.AddAttribute("auto_pad", "NOTSET");
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<Conv1dReplacement>();
|
||||
|
|
@ -1272,8 +1277,16 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) {
|
|||
TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, pre_graph_checker));
|
||||
}
|
||||
}
|
||||
|
||||
// "auto_pad" is not NOTSET so conv not replaced
|
||||
TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect2) {
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
auto op_count_map = CountOpsInGraph(graph);
|
||||
TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
// "kernel_shape" is not 1 so conv not replaced
|
||||
for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
|
||||
|
|
@ -1285,9 +1298,44 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) {
|
|||
|
||||
auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output});
|
||||
conv_node.AddAttribute("dilations", std::vector<int64_t>{1});
|
||||
conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{2});
|
||||
conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
|
||||
conv_node.AddAttribute("strides", std::vector<int64_t>{1});
|
||||
conv_node.AddAttribute("group", static_cast<int64_t>(1));
|
||||
conv_node.AddAttribute("pads", std::vector<int64_t>{0, 0});
|
||||
conv_node.AddAttribute("auto_pad", "VALID");
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<Conv1dReplacement>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer),
|
||||
TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, pre_graph_checker));
|
||||
}
|
||||
}
|
||||
|
||||
// pads is not all zero, so conv not replaced
|
||||
TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect3) {
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
auto op_count_map = CountOpsInGraph(graph);
|
||||
TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1);
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128);
|
||||
auto out_channel = 64;
|
||||
auto* data_arg = builder.MakeInput<float>({{batch_size, in_channel, in_length}});
|
||||
|
||||
auto* weight_arg = builder.MakeInitializer<float>({out_channel, in_channel, 1}, {-1.0f, 1.0f});
|
||||
auto* conv_output = builder.MakeOutput();
|
||||
|
||||
auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output});
|
||||
conv_node.AddAttribute("dilations", std::vector<int64_t>{1});
|
||||
conv_node.AddAttribute("kernel_shape", std::vector<int64_t>{1});
|
||||
conv_node.AddAttribute("strides", std::vector<int64_t>{1});
|
||||
conv_node.AddAttribute("group", static_cast<int64_t>(1));
|
||||
conv_node.AddAttribute("pads", std::vector<int64_t>{1, 0});
|
||||
conv_node.AddAttribute("auto_pad", "NOTSET");
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<Conv1dReplacement>();
|
||||
|
|
|
|||
Loading…
Reference in a new issue