From d220c9f9504131132a1e2a314a1d4ec4e38676cb Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Wed, 21 Oct 2020 13:18:19 -0700 Subject: [PATCH] Resolve crash in MatMul optimization (#5551) * check pointer before referencing * add test case * switch to ASSERT_EQ --- .../core/optimizer/matmul_add_fusion.cc | 3 +++ .../test/optimizer/graph_transform_test.cc | 17 +++++++++++++++++ .../matmul_add_missing_shape.onnx | 19 +++++++++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 onnxruntime/test/testdata/transform/matmul_add_fusion/matmul_add_missing_shape.onnx diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index de83c32417..6ed0b3734d 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -88,6 +88,9 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, // valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as // GEMM only supports unidirectional broadcast on the bias input C + if (!gemm_input_defs.back()->Shape()) { + continue; + } const auto& bias_shape = *gemm_input_defs.back()->Shape(); const auto& M = matmul_output.Shape()->dim()[0]; const auto& N = matmul_output.Shape()->dim()[1]; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index a4e155783b..85c23aae44 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -681,6 +681,23 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_NotBroadcastable) { ASSERT_TRUE(op_to_count["Gemm"] == 0); } +TEST_F(GraphTransformationTests, MatMulAddFusion_MissingShape) { + auto model_uri = MODEL_FOLDER "matmul_add_fusion/matmul_add_missing_shape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level1); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Add"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); +} + #ifndef DISABLE_CONTRIB_OPS TEST_F(GraphTransformationTests, Gemm_Relu_three_input) { auto model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/gemm_relu.onnx"; diff --git a/onnxruntime/test/testdata/transform/matmul_add_fusion/matmul_add_missing_shape.onnx b/onnxruntime/test/testdata/transform/matmul_add_fusion/matmul_add_missing_shape.onnx new file mode 100644 index 0000000000..d029c6e784 --- /dev/null +++ b/onnxruntime/test/testdata/transform/matmul_add_fusion/matmul_add_missing_shape.onnx @@ -0,0 +1,19 @@ +:j + +A +BS"MatMul + +S +CD"AddgraphZ +A +  + +Z +B +  + +Z +C +b +D +B \ No newline at end of file