Resolve crash in MatMul optimization (#5551)

* check pointer before referencing

* add test case

* switch to ASSERT_EQ
This commit is contained in:
RandySheriffH 2020-10-21 13:18:19 -07:00 committed by GitHub
parent 5802fe1699
commit d220c9f950
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 0 deletions

View file

@ -88,6 +88,9 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as
// GEMM only supports unidirectional broadcast on the bias input C
if (!gemm_input_defs.back()->Shape()) {
continue;
}
const auto& bias_shape = *gemm_input_defs.back()->Shape();
const auto& M = matmul_output.Shape()->dim()[0];
const auto& N = matmul_output.Shape()->dim()[1];

View file

@ -681,6 +681,23 @@ TEST_F(GraphTransformationTests, MatMulAddFusion_NotBroadcastable) {
ASSERT_TRUE(op_to_count["Gemm"] == 0);
}
TEST_F(GraphTransformationTests, MatMulAddFusion_MissingShape) {
auto model_uri = MODEL_FOLDER "matmul_add_fusion/matmul_add_missing_shape.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
graph_transformation_mgr.Register(onnxruntime::make_unique<MatMulAddFusion>(), TransformerLevel::Level1);
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_EQ(op_to_count["MatMul"], 1);
ASSERT_EQ(op_to_count["Add"], 1);
ASSERT_EQ(op_to_count["Gemm"], 0);
}
#ifndef DISABLE_CONTRIB_OPS
TEST_F(GraphTransformationTests, Gemm_Relu_three_input) {
auto model_uri = MODEL_FOLDER "matmul_add_fusion/3Input/gemm_relu.onnx";

View file

@ -0,0 +1,19 @@
:j

A
BS"MatMul

S
CD"AddgraphZ
A


Z
B


Z
C
b
D
B