onnxruntime/onnxruntime/contrib_ops
Jing Fang 13348c572a
[ARM CPU] hgemm optimized for gqa (#23107)
### Description
Add fp16 kernels for GQA (grouped-query attention) matmul on ARM CPU.
The kernels are MLAS HGEMM routines computing C = alpha * A x B' + beta * C, where B' is the transpose of B.
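For readers unfamiliar with the transposed-B form, here is a minimal scalar reference of the same computation. It is a sketch under stated assumptions: row-major storage, `float` instead of fp16, and a hypothetical function name `ReferenceGemmTransB`. It is not the MLAS API and omits the vectorization and tiling an optimized kernel would use. The A x B' shape matches the Q x K' score product in attention, which is presumably why this variant is the one that was optimized.

```cpp
#include <cstddef>

// Hypothetical scalar reference (not the MLAS API; float, not fp16).
// Computes C = alpha * A * B^T + beta * C with row-major storage:
//   A is M x K, B is N x K (so B^T is K x N), C is M x N.
void ReferenceGemmTransB(std::size_t M, std::size_t N, std::size_t K,
                         float alpha, const float* A, const float* B,
                         float beta, float* C) {
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      // Row m of A dotted with row n of B equals (A * B^T)[m][n].
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[n * K + k];
      }
      // Common BLAS convention: when beta is 0, C is not read.
      const float prev = (beta == 0.0f) ? 0.0f : beta * C[m * N + n];
      C[m * N + n] = alpha * acc + prev;
    }
  }
}
```

With alpha = 1 and beta = 0 this reduces to a plain A x B'.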


### Motivation and Context
Add fp16 support for GQA, speed up the operator and reduce memory usage.

__Token Generation__
| Shape (M/N/K) | HGEMM Runtime (ns) | SGEMM Runtime (ns) | Speed-up (%) |
|---------------------------------|--------------------|--------------------|--------------|
| M:1/N:4096/K:4096 | 251551 | 1775905 | 85.84 |
| M:1/N:11008/K:4096 | 892507 | 4649145 | 80.80 |
| M:1/N:4096/K:11008 | 866860 | 3240015 | 73.25 |
| M:1/N:11008/K:11008 | 2631615 | 8783877 | 70.04 |

__Prompting__
| Shape (M/N/K) | HGEMM Runtime (ns) | SGEMM Runtime (ns) | Speed-up (%) |
|---------------------------------|--------------------|--------------------|--------------|
| M:1024/N:4096/K:4096 | 90508701 | 111283029 | 18.67 |
| M:2048/N:4096/K:4096 | 181307522 | 240211107 | 24.52 |
| M:1024/N:11008/K:4096 | 241120234 | 307707933 | 21.64 |
| M:2048/N:11008/K:4096 | 481091232 | 648921367 | 25.86 |
| M:1024/N:4096/K:11008 | 241736343 | 310129880 | 22.05 |
| M:2048/N:4096/K:11008 | 480456703 | 644814999 | 25.49 |
| M:1024/N:11008/K:11008 | 642121440 | 847925766 | 24.27 |
| M:2048/N:11008/K:11008 | 1276097154 | 1731314509 | 26.29 |
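The percentage column in these tables appears to be the share of SGEMM runtime saved, (SGEMM - HGEMM) / SGEMM x 100, rather than a throughput ratio. For example, the first token-generation row gives (1775905 - 251551) / 1775905 ≈ 85.84%. The sketch below, using hypothetical helper names, reproduces that arithmetic alongside the equivalent ratio, which for the same row is about 7.06x.

```cpp
// Hypothetical helpers for reading the benchmark tables.

// Percentage of SGEMM runtime saved by HGEMM; matches the tables' percentage column.
double RuntimeReductionPercent(double hgemm_ns, double sgemm_ns) {
  return (sgemm_ns - hgemm_ns) / sgemm_ns * 100.0;
}

// Equivalent speed-up expressed as a ratio (how many times faster HGEMM is).
double SpeedupRatio(double hgemm_ns, double sgemm_ns) {
  return sgemm_ns / hgemm_ns;
}
```

By the same formula, the prompting rows (about 18.7% to 26.3% saved) correspond to roughly 1.23x to 1.36x.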
2025-01-24 15:25:24 -08:00
cpu [ARM CPU] hgemm optimized for gqa (#23107) 2025-01-24 15:25:24 -08:00
cuda Stable Diffusion 3.x and Flux Optimization (#22986) 2025-01-14 13:37:58 -08:00
js
rocm
webgpu WIP: Dp4MatMulNBits accuracy level 4 matmul for WebGPU EP (#23365) 2025-01-21 15:46:51 -08:00