diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index 6f3d3414ca..2c43f5ab12 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -131,6 +131,11 @@ std::vector GetInputNodeMerges( ORT_ENFORCE(input_node.InputDefs().size() == 2 && scale_and_index->second < 2); const int to_scale_index = 1 - scale_and_index->second; + // check if the non-scale input is scalar + const auto non_scale_node_arg = input_node.InputDefs()[to_scale_index]; + const auto* shape = non_scale_node_arg->Shape(); + if (shape == nullptr || shape->dim_size() == 0) continue; + input_node_merges.push_back( {input_edge, scale_and_index->first, diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 591470fdd6..56880054b3 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4727,6 +4727,18 @@ TEST_F(GraphTransformationTests, MatMulScaleFusionUnsupportedInputType) { {kCpuExecutionProvider}); } +TEST_F(GraphTransformationTests, MatMulScaleFusionWithScaleInput) { + TestMatMulScaleFusion( + MODEL_FOLDER "fusion/matmul_scale_with_scale_input.onnx", *logger_, + [](const Graph&, + const std::map&, + std::map transformed_op_counts) { + EXPECT_EQ(transformed_op_counts["Mul"], 1); + EXPECT_EQ(transformed_op_counts["MatMul"], 1); + EXPECT_EQ(transformed_op_counts["com.microsoft.FusedMatMul"], 0); + }); +} + #if defined(USE_CUDA) || defined(USE_ROCM) TEST_F(GraphTransformationTests, IsInfReduceSum_Test) { auto model_uri = MODEL_FOLDER "fusion/isinf_reducesum.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py b/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py index 89e6baec71..850ff6bc22 100644 --- a/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py @@ -16,7 +16,7 @@ opsets = [onnxdomain, msdomain] scale_value = 3.0 -def save(model_path, nodes, inputs, outputs, initializers): +def save(model_path, nodes, inputs, outputs, initializers, opsets=opsets): graph = helper.make_graph( nodes, "MatMulScaleTest", @@ -218,3 +218,40 @@ def gen_int32(model_path): gen_int32("matmul_scale_int32.onnx") + + +def gen_scale_input(model_path): + + nodes = [ + helper.make_node( + "Mul", ["input_0", "scale"], ["scaled_input_0"], + "scale input_0"), + helper.make_node( + "MatMul", ["scaled_input_0", "input_1"], ["output_0"], + "MatMul input_0 and input_1"), + ] + + initializers = [ + helper.make_tensor("scale", TensorProto.FLOAT, [1], [1.0]) + ] + + inputs = [ + helper.make_tensor_value_info( + "input_0", TensorProto.FLOAT, []), + helper.make_tensor_value_info( + "input_1", TensorProto.FLOAT, [1, 'K']), + ] + + outputs = [ + helper.make_tensor_value_info( + "output_0", TensorProto.FLOAT, ['K']), + ] + + onnxdomain = OperatorSetIdProto() + onnxdomain.version = 14 + onnxdomain.domain = "" + opsets = [onnxdomain, msdomain] + save(model_path, nodes, inputs, outputs, initializers, opsets) + + +gen_scale_input("matmul_scale_with_scale_input.onnx") \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx new file mode 100644 index 0000000000..eb59332fe1 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx differ