mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
### Multi Query Attention Optimization

In multi-query attention, the fused QKV tensor is split into query, key, and value like this:

```
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
```

which can be optimized to

```
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
(query, key, value) = fused_qkv.split([self.num_heads, 1, 1], dim=2)
return query, key, value
```

This optimization can be validated with Nsight profiling and perf benchmarking.

<img width="545" alt="image" src="https://github.com/microsoft/onnxruntime/assets/15321482/cefcd061-4a01-4aaf-a008-8e265f7f63e9">

As such, this PR fuses the `Gather`/`Gather`/`Slice` ops into a single `Split` kernel.

### Optimization Target

Two `Gather` kernels and one `Slice` kernel are time-consuming in backward propagation, so it is more efficient to use a single `Split` kernel.

### Example

- Before Fusion

  <img width="419" alt="image" src="https://github.com/microsoft/onnxruntime/assets/15321482/17410319-57ea-4176-afd4-1efdcd3fdbae">

- After Fusion

  <img width="424" alt="image" src="https://github.com/microsoft/onnxruntime/assets/15321482/f1ee1582-96d4-45f4-8778-49d1f3fd370a">

### Perf Gain

After the optimization, there is a **~7%** perf gain.

> The `Transpose` kernel can be fused too; this will be updated in the next PR.

However, after testing `Transpose` op fusion on the Falcon model, there is no perf gain, so no new PR will be created.

---------

Co-authored-by: ruiren <ruiren@microsoft.com>
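The equivalence of the two forms in the description can be checked numerically. The following is a minimal sketch, assuming PyTorch is installed; the function names `split_heads_indexing` and `split_heads_split` and the tensor sizes are hypothetical, chosen only for illustration, and the snippet does not measure performance.

```python
import torch


def split_heads_indexing(fused_qkv, num_heads, head_dim):
    # Original form from the description: one slice plus two list-index lookups
    # (the pattern the PR refers to as Slice/Gather/Gather).
    batch_size, seq_length, _ = fused_qkv.shape
    fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads + 2, head_dim)
    return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]


def split_heads_split(fused_qkv, num_heads, head_dim):
    # Optimized form from the description: a single split along the head axis.
    batch_size, seq_length, _ = fused_qkv.shape
    fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads + 2, head_dim)
    query, key, value = fused_qkv.split([num_heads, 1, 1], dim=2)
    return query, key, value


# Hypothetical sizes, small enough to run on CPU.
torch.manual_seed(0)
batch_size, seq_length, num_heads, head_dim = 2, 5, 4, 8
fused = torch.randn(batch_size, seq_length, (num_heads + 2) * head_dim)

old = split_heads_indexing(fused, num_heads, head_dim)
new = split_heads_split(fused, num_heads, head_dim)

# Both forms should agree in shape and in every element.
shapes_match = all(a.shape == b.shape for a, b in zip(old, new))
values_match = all(torch.equal(a, b) for a, b in zip(old, new))
print(shapes_match, values_match)
```

The query keeps `num_heads` heads while key and value each keep a single shared head, which is why the split sizes are `[num_heads, 1, 1]`.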
| Name |
|---|
| .. |
| orttraining |
| tools |