Enable float16 MatMul+Add -> GEMM fusion for performance boost (#1506)

This commit is contained in:
Dwayne Robinson 2019-07-29 15:18:02 -07:00 committed by Changming Sun
parent cf5a4b5856
commit cf73f63cb9

View file

@ -43,10 +43,13 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level)
auto matmul_input_defs = matmul_node.MutableInputDefs();
auto add_input_defs = add_node.MutableInputDefs();
// Gemm only support float, so the inputs of MatMul
// Gemm requires that inputs be the same data type and both floating point (float32/float16).
auto matmul_type = matmul_input_defs[0]->Type();
auto add_type = add_input_defs[0]->Type();
if ((*matmul_type) != "tensor(float)" || (*add_type) != "tensor(float)") {
if ((*matmul_type) != (*add_type)) {
continue;
}
if ((*matmul_type) != "tensor(float)" && (*matmul_type) != "tensor(float16)") {
continue;
}