mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
For the following code found in some transformers models:

```
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
```

the exported graph contains 3 Gather nodes, and ORT's current GatherGrad CUDA implementation is slow. This pattern can be fused into a single Split, which launches fewer kernels for the computation; Split and Concat (used for the gradient) also perform better than Gather and GatherGrad. In a real example, one GatherGrad takes 15ms and there are 3 per layer in the graph; after the fusion, one Concat takes only 35us. The total time of a step improves from 1.5s to 0.4s.
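The equivalence behind this fusion can be checked in eager PyTorch. The sketch below is only an illustration, not ORT's graph transformer: the tensor sizes and the helper names `split_by_indexing`, `split_by_split`, and `grad_of` are made up for the example. It compares the three integer-index slices with one split along the size-3 axis, for both the outputs and the input gradients.

```python
import torch

# Illustrative sizes only; they are not taken from any real model.
batch_size, seq_length, num_heads, head_dim = 2, 4, 3, 5


def split_by_indexing(fused_qkv):
    # Pattern from the commit message: three integer-index slices
    # (the message says these show up as 3 Gather nodes after export).
    return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]


def split_by_split(fused_qkv):
    # Equivalent form: one split along the size-3 axis, then drop that axis.
    q, k, v = torch.split(fused_qkv, 1, dim=-2)
    return q.squeeze(-2), k.squeeze(-2), v.squeeze(-2)


def grad_of(fn, x):
    # Gradient of a weighted sum of the three outputs w.r.t. the fused input.
    inp = x.clone().requires_grad_(True)
    q, k, v = fn(inp.view(batch_size, seq_length, num_heads, 3, head_dim))
    (q.sum() + 2 * k.sum() + 3 * v.sum()).backward()
    return inp.grad


torch.manual_seed(0)
x = torch.randn(batch_size, seq_length, num_heads * 3 * head_dim)
fused = x.view(batch_size, seq_length, num_heads, 3, head_dim)

gathered = split_by_indexing(fused)
split = split_by_split(fused)

outputs_match = all(torch.equal(a, b) for a, b in zip(gathered, split))
grads_match = torch.allclose(
    grad_of(split_by_indexing, x), grad_of(split_by_split, x)
)
print(outputs_match, grads_match)
```

Both forms produce the same values and gradients, which is why replacing the slices with a single split is safe; the speedup reported above comes from the kernels ORT runs for each form, which this eager-mode check does not measure.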
| Directory |
|---|
| orttraining |
| pytorch_frontend_examples |
| tools |