rui-ren a67e692546
Add GatherSliceToSplitFusion and unit test (#19218)
### Multi-Query Attention Optimization

In multi-query attention, the fused QKV tensor is split into query, key, and value heads as follows:
```python
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
# Exported to ONNX as one Slice (:-2) and two Gathers ([-2] and [-1]) on dim 2
return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
```
This can be rewritten to use a single `split`:
```python
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
# One Split on dim 2: num_heads query heads, then 1 key head and 1 value head
(query, key, value) = fused_qkv.split([self.num_heads, 1, 1], dim=2)
return query, key, value
```
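To check that the two forms above are interchangeable, the sketch below compares them numerically. It uses NumPy as a stand-in for PyTorch (an assumption for portability): `reshape` replaces `view`, and `np.split` takes cut points rather than the sizes that `torch.split` takes. The helper names and the sizes (`num_heads = 4`, `head_dim = 8`) are illustrative only.

```python
import numpy as np

def split_heads_indexing(fused_qkv, num_heads, head_dim):
    """Original form: exported to ONNX as one Slice and two Gathers."""
    batch_size, seq_length, _ = fused_qkv.shape
    x = fused_qkv.reshape(batch_size, seq_length, num_heads + 2, head_dim)
    return x[..., :-2, :], x[..., [-2], :], x[..., [-1], :]

def split_heads_split(fused_qkv, num_heads, head_dim):
    """Rewritten form: one split along the head axis (axis 2)."""
    batch_size, seq_length, _ = fused_qkv.shape
    x = fused_qkv.reshape(batch_size, seq_length, num_heads + 2, head_dim)
    # np.split takes cut points rather than sizes: cutting at
    # [num_heads, num_heads + 1] yields sizes [num_heads, 1, 1].
    query, key, value = np.split(x, [num_heads, num_heads + 1], axis=2)
    return query, key, value

num_heads, head_dim = 4, 8
fused = np.random.default_rng(0).standard_normal((2, 5, (num_heads + 2) * head_dim))
before = split_heads_indexing(fused, num_heads, head_dim)
after = split_heads_split(fused, num_heads, head_dim)
assert all(np.array_equal(a, b) for a, b in zip(before, after))
print([t.shape for t in after])  # [(2, 5, 4, 8), (2, 5, 1, 8), (2, 5, 1, 8)]
```

Note that the list index `[-2]` (rather than the scalar `-2`) keeps the head axis with size 1, which is what lets both `Gather` outputs line up with the corresponding `Split` outputs.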

The benefit of this rewrite can be confirmed with Nsight profiling and performance benchmarking:

<img width="545" alt="image"
src="https://github.com/microsoft/onnxruntime/assets/15321482/cefcd061-4a01-4aaf-a008-8e265f7f63e9">

Accordingly, this PR fuses the `Gather`/`Gather`/`Slice` ops into a single `Split` kernel.
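The fusion is only valid when the three consumers read the same input along the same axis and together cover that axis exactly once. The sketch below expresses that condition on a simplified, hypothetical node representation (`Consumer`, `to_interval`, and `split_sizes` are illustrative names, not ORT's C++ API). It assumes each `Gather` uses a 1-D indices tensor with a single element (so the axis is kept, as with `Split`) and each `Slice` has step 1.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class Consumer:
    """Hypothetical, simplified view of one node reading the shared input."""
    op: str         # "Gather" or "Slice"
    axis: int
    index: int = 0  # Gather: the only element of a 1-D indices tensor
    start: int = 0  # Slice: start/end on `axis`; step assumed to be 1
    end: int = 0

def to_interval(node: Consumer, dim: int) -> Optional[Tuple[int, int]]:
    """Half-open range [lo, hi) that the node reads along its axis."""
    if node.op == "Gather":
        if not -dim <= node.index < dim:
            return None
        lo = node.index % dim  # normalize a negative index
        return lo, lo + 1
    if node.op == "Slice":
        lo = node.start + dim if node.start < 0 else node.start
        hi = node.end + dim if node.end < 0 else node.end
        return max(0, lo), min(dim, hi)  # clamp like ONNX Slice
    return None

def split_sizes(consumers: List[Consumer], dim: int) -> Optional[List[int]]:
    """Split sizes if the consumers tile [0, dim) exactly, else None."""
    if len({c.axis for c in consumers}) != 1:
        return None
    intervals = [to_interval(c, dim) for c in consumers]
    if any(iv is None for iv in intervals):
        return None
    cursor, sizes = 0, []
    for lo, hi in sorted(intervals):
        if lo != cursor or hi <= lo:
            return None  # gap, overlap, or empty range: not a Split
        sizes.append(hi - lo)
        cursor = hi
    return sizes if cursor == dim else None

# The MQA pattern with num_heads = 4, so the head axis has size 6.
mqa = [Consumer("Slice", axis=2, start=0, end=-2),
       Consumer("Gather", axis=2, index=-2),
       Consumer("Gather", axis=2, index=-1)]
print(split_sizes(mqa, dim=6))  # [4, 1, 1]
```

The sizes come out in axis order, so a real rewrite would also need to reconnect each original consumer's downstream edges to the matching `Split` output.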

### Optimization Target

Two `Gather` kernels and one `Slice` kernel are time-consuming in backward propagation, so replacing them with a single `Split` kernel is more efficient.
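One plausible contributor to the backward cost can be illustrated with generic reverse-mode autodiff (this is a conceptual NumPy sketch, not ORT's actual gradient kernels). When three separate ops read the same tensor, each one produces a full-size, mostly-zero gradient buffer for that tensor, and the three buffers are then summed. The gradient of a single `Split` is just a concatenation of its output gradients. The function names and shapes below are hypothetical.

```python
import numpy as np

def grad_gather_gather_slice(grads, dim, axis=2):
    """Input gradient when Slice(:-2), Gather([-2]) and Gather([-1]) each read
    the same tensor: every consumer scatters its gradient into its own
    full-size zero buffer, and the buffers are summed."""
    shape = list(grads[0].shape)
    shape[axis] = dim
    total = np.zeros(shape)
    selections = (slice(0, dim - 2), [dim - 2], [dim - 1])
    for g, sel in zip(grads, selections):
        buf = np.zeros(shape)              # one full-size buffer per consumer
        index = [slice(None)] * len(shape)
        index[axis] = sel
        buf[tuple(index)] = g              # scatter into the selected heads
        total += buf                       # accumulate the three contributions
    return total

def grad_split(grads, axis=2):
    """Input gradient of a single Split: concatenate the output gradients."""
    return np.concatenate(grads, axis=axis)

gen = np.random.default_rng(0)
grads = [gen.standard_normal((2, 5, n, 8)) for n in (4, 1, 1)]
assert np.allclose(grad_gather_gather_slice(grads, dim=6), grad_split(grads))
```

Both paths give the same gradient, but the first allocates and sums three input-sized tensors while the second writes each element once.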


### Example

- Before Fusion
<img width="419" alt="image"
src="https://github.com/microsoft/onnxruntime/assets/15321482/17410319-57ea-4176-afd4-1efdcd3fdbae">
 
- After Fusion
<img width="424" alt="image"
src="https://github.com/microsoft/onnxruntime/assets/15321482/f1ee1582-96d4-45f4-8778-49d1f3fd370a">

### Perf Gain
After the optimization, we observe a **~7%** performance gain.

> The `Transpose` kernel could also be fused, which was planned for a follow-up PR.
> However, testing `Transpose` fusion on the Falcon model showed no performance gain, so no follow-up PR will be created.

---------

Co-authored-by: ruiren <ruiren@microsoft.com>
2024-02-14 15:07:56 -08:00