mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Resolve crash in MatMul optimization (#5551)
* check pointer before referencing * add test case * switch to ASSERT_EQ
This commit is contained in:
parent
5802fe1699
commit
d220c9f950
3 changed files with 39 additions and 0 deletions
|
|
@ -88,6 +88,9 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
|
||||
// valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as
|
||||
// GEMM only supports unidirectional broadcast on the bias input C
|
||||
if (!gemm_input_defs.back()->Shape()) {
|
||||
continue;
|
||||
}
|
||||
const auto& bias_shape = *gemm_input_defs.back()->Shape();
|
||||
const auto& M = matmul_output.Shape()->dim()[0];
|
||||
const auto& N = matmul_output.Shape()->dim()[1];
|
||||
|
|
|
|||
|
|
@ -681,6 +681,23 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_NotBroadcastable) {
|
|||
ASSERT_TRUE(op_to_count["Gemm"] == 0);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, MatMulAddFusion_MissingShape) {
|
||||
auto model_uri = MODEL_FOLDER "matmul_add_fusion/matmul_add_missing_shape.onnx";
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(onnxruntime::make_unique<MatMulAddFusion>(), TransformerLevel::Level1);
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["MatMul"], 1);
|
||||
ASSERT_EQ(op_to_count["Add"], 1);
|
||||
ASSERT_EQ(op_to_count["Gemm"], 0);
|
||||
}
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
TEST_F(GraphTransformationTests, Gemm_Relu_three_input) {
|
||||
auto model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/gemm_relu.onnx";
|
||||
|
|
|
|||
19
onnxruntime/test/testdata/transform/matmul_add_fusion/matmul_add_missing_shape.onnx
vendored
Normal file
19
onnxruntime/test/testdata/transform/matmul_add_fusion/matmul_add_missing_shape.onnx
vendored
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
:j
|
||||
|
||||
A
|
||||
BS"MatMul
|
||||
|
||||
S
|
||||
CD"AddgraphZ
|
||||
A
|
||||
|
||||
|
||||
Z
|
||||
B
|
||||
|
||||
|
||||
Z
|
||||
C
|
||||
b
|
||||
D
|
||||
B
|
||||
Loading…
Reference in a new issue