From bc296c706ebd69e02c02f77f319710c984633266 Mon Sep 17 00:00:00 2001 From: ytaous <4484531+ytaous@users.noreply.github.com> Date: Thu, 14 Apr 2022 21:54:04 -0700 Subject: [PATCH] MatMulScaleFusion - handling scale input (#11121) * scale input * more condition check * alternative * per comments * fix comments Co-authored-by: Ethan Tao --- .../core/optimizer/matmul_scale_fusion.cc | 5 +++ .../test/optimizer/graph_transform_test.cc | 12 ++++++ .../transform/fusion/matmul_scale_gen.py | 39 +++++++++++++++++- .../fusion/matmul_scale_with_scale_input.onnx | Bin 0 -> 283 bytes 4 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx diff --git a/onnxruntime/core/optimizer/matmul_scale_fusion.cc b/onnxruntime/core/optimizer/matmul_scale_fusion.cc index 6f3d3414ca..2c43f5ab12 100644 --- a/onnxruntime/core/optimizer/matmul_scale_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_scale_fusion.cc @@ -131,6 +131,11 @@ std::vector GetInputNodeMerges( ORT_ENFORCE(input_node.InputDefs().size() == 2 && scale_and_index->second < 2); const int to_scale_index = 1 - scale_and_index->second; + // check if the non-scale input is scalar + const auto non_scale_node_arg = input_node.InputDefs()[to_scale_index]; + const auto* shape = non_scale_node_arg->Shape(); + if (shape == nullptr || shape->dim_size() == 0) continue; + input_node_merges.push_back( {input_edge, scale_and_index->first, diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 591470fdd6..56880054b3 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4727,6 +4727,18 @@ TEST_F(GraphTransformationTests, MatMulScaleFusionUnsupportedInputType) { {kCpuExecutionProvider}); } +TEST_F(GraphTransformationTests, MatMulScaleFusionWithScaleInput) { + TestMatMulScaleFusion( + MODEL_FOLDER "fusion/matmul_scale_with_scale_input.onnx", *logger_, + [](const Graph&, + const std::map&, + std::map transformed_op_counts) { + EXPECT_EQ(transformed_op_counts["Mul"], 1); + EXPECT_EQ(transformed_op_counts["MatMul"], 1); + EXPECT_EQ(transformed_op_counts["com.microsoft.FusedMatMul"], 0); + }); +} + #if defined(USE_CUDA) || defined(USE_ROCM) TEST_F(GraphTransformationTests, IsInfReduceSum_Test) { auto model_uri = MODEL_FOLDER "fusion/isinf_reducesum.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py b/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py index 89e6baec71..850ff6bc22 100644 --- a/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py @@ -16,7 +16,7 @@ opsets = [onnxdomain, msdomain] scale_value = 3.0 -def save(model_path, nodes, inputs, outputs, initializers): +def save(model_path, nodes, inputs, outputs, initializers, opsets=opsets): graph = helper.make_graph( nodes, "MatMulScaleTest", @@ -218,3 +218,40 @@ def gen_int32(model_path): gen_int32("matmul_scale_int32.onnx") + + +def gen_scale_input(model_path): + + nodes = [ + helper.make_node( + "Mul", ["input_0", "scale"], ["scaled_input_0"], + "scale input_0"), + helper.make_node( + "MatMul", ["scaled_input_0", "input_1"], ["output_0"], + "MatMul input_0 and input_1"), + ] + + initializers = [ + helper.make_tensor("scale", TensorProto.FLOAT, [1], [1.0]) + ] + + inputs = [ + helper.make_tensor_value_info( + "input_0", TensorProto.FLOAT, []), + helper.make_tensor_value_info( + "input_1", TensorProto.FLOAT, [1, 'K']), + ] + + outputs = [ + helper.make_tensor_value_info( + "output_0", TensorProto.FLOAT, ['K']), + ] + + onnxdomain = OperatorSetIdProto() + onnxdomain.version = 14 + onnxdomain.domain = "" + opsets = [onnxdomain, msdomain] + save(model_path, nodes, inputs, outputs, initializers, opsets) + + +gen_scale_input("matmul_scale_with_scale_input.onnx") \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_scale_with_scale_input.onnx new file mode 100644 index 0000000000000000000000000000000000000000..eb59332fe14853ff65dece88e3d820be2697939f GIT binary patch literal 283 zcmZ8by9&ZU5X@;3&y~Ek5v~d;1tkG18zGH_SZQS=F-8a^camJd(*N}*JoB)y+MUPF z4yEV{&YxAoWnN}WnbmUhCa{J{&L54ugkaQZQe41?=F@HKasH zMI+lh#tw;8=Hb#$_^scg!0a@1x_SJ5KB-Czj}n&zx~9FhaaZW<4dEf0pi_d{nGIBg r4p@qs!iW+WBiUc)@T-6`m@3g|g;)p8wPO#gRFtbSPb*Q2yL8D1z8Fn9 literal 0 HcmV?d00001