diff --git a/onnxruntime/core/optimizer/matmul_add_fusion.cc b/onnxruntime/core/optimizer/matmul_add_fusion.cc index a0c392ef75..33bc507990 100644 --- a/onnxruntime/core/optimizer/matmul_add_fusion.cc +++ b/onnxruntime/core/optimizer/matmul_add_fusion.cc @@ -43,10 +43,13 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) auto matmul_input_defs = matmul_node.MutableInputDefs(); auto add_input_defs = add_node.MutableInputDefs(); - // Gemm only support float, so the inputs of MatMul + // Gemm requires that inputs be the same data type and both floating point (float32/float16). auto matmul_type = matmul_input_defs[0]->Type(); auto add_type = add_input_defs[0]->Type(); - if ((*matmul_type) != "tensor(float)" || (*add_type) != "tensor(float)") { + if ((*matmul_type) != (*add_type)) { + continue; + } + if ((*matmul_type) != "tensor(float)" && (*matmul_type) != "tensor(float16)") { continue; }