mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
Enable float16 MatMul+Add -> GEMM fusion for performance boost (#1506)
This commit is contained in:
parent
cf5a4b5856
commit
cf73f63cb9
1 changed files with 5 additions and 2 deletions
|
|
@ -43,10 +43,13 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level)
|
|||
auto matmul_input_defs = matmul_node.MutableInputDefs();
|
||||
auto add_input_defs = add_node.MutableInputDefs();
|
||||
|
||||
// Gemm only support float, so the inputs of MatMul
|
||||
// Gemm requires that inputs be the same data type and both floating point (float32/float16).
|
||||
auto matmul_type = matmul_input_defs[0]->Type();
|
||||
auto add_type = add_input_defs[0]->Type();
|
||||
if ((*matmul_type) != "tensor(float)" || (*add_type) != "tensor(float)") {
|
||||
if ((*matmul_type) != (*add_type)) {
|
||||
continue;
|
||||
}
|
||||
if ((*matmul_type) != "tensor(float)" && (*matmul_type) != "tensor(float16)") {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue