onnxruntime/orttraining
Vincent Wang 6c63c1c9ee
Multiple Gather to Split Fusion (#13095)
For the following code, which appears in some Transformer models:
```python
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
```

The exported graph contains three Gather nodes, and ORT's current CUDA
implementation of GatherGrad is slow. This pattern can be fused into a
single Split, so fewer kernels are launched for the computation. Split and
its gradient, Concat, also perform better than Gather and GatherGrad.
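The rewrite described above can be sketched on a toy graph. This is a minimal illustration under stated assumptions, not ORT's optimizer code: `Node`, `fuse_gathers_to_split`, and the `index` attribute are hypothetical, and attributes are simplified (in real ONNX, Gather's indices and Split/Squeeze parameters are inputs, not attributes). The sketch fuses a group of Gathers only when they read the same tensor along the same axis and together cover every index of that axis exactly once. It then emits one Split followed by a Squeeze per output, because Split keeps a size-1 axis that a scalar-index Gather removes.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy graph node (hypothetical IR, not ORT's Graph API)."""
    op_type: str
    inputs: list
    outputs: list
    attrs: dict = field(default_factory=dict)


def fuse_gathers_to_split(nodes, shapes):
    """Replace Gathers that slice every index of one axis with Split + Squeeze.

    `shapes` maps tensor names to static shapes. attrs["index"] is a
    hypothetical stand-in for ONNX Gather's constant scalar `indices` input.
    """
    groups = {}
    for node in nodes:
        if node.op_type == "Gather" and isinstance(node.attrs.get("index"), int):
            data = node.inputs[0]
            axis = node.attrs.get("axis", 0) % len(shapes[data])  # normalize negative axis
            groups.setdefault((data, axis), []).append(node)

    replacement = {}  # id(first Gather of a group) -> fused nodes
    replaced = set()  # ids of all Gathers that were fused away
    for (data, axis), gathers in groups.items():
        dim = shapes[data][axis]
        if any(not -dim <= g.attrs["index"] < dim for g in gathers):
            continue  # out-of-range index: leave the graph unchanged
        by_index = {g.attrs["index"] % dim: g for g in gathers}
        # Fuse only when each index 0..dim-1 is read exactly once.
        if len(gathers) != dim or sorted(by_index) != list(range(dim)):
            continue
        split_outs = [f"{data}_split_{i}" for i in range(dim)]
        new_nodes = [Node("Split", [data], split_outs, {"axis": axis})]
        for i in range(dim):
            # Squeeze drops the size-1 axis so shapes match the old Gather outputs.
            new_nodes.append(
                Node("Squeeze", [split_outs[i]], by_index[i].outputs, {"axes": [axis]})
            )
        replacement[id(gathers[0])] = new_nodes
        replaced.update(id(g) for g in gathers)

    # Emit fused nodes where the first Gather stood, keeping topological order.
    result = []
    for node in nodes:
        if id(node) in replacement:
            result.extend(replacement[id(node)])
        elif id(node) not in replaced:
            result.append(node)
    return result


shapes = {"qkv": [2, 8, 4, 3, 16]}  # batch, seq, heads, 3, head_dim (example sizes)
graph = [
    Node("Gather", ["qkv"], ["q"], {"axis": -2, "index": 0}),
    Node("Gather", ["qkv"], ["k"], {"axis": -2, "index": 1}),
    Node("Gather", ["qkv"], ["v"], {"axis": -2, "index": 2}),
]
print([n.op_type for n in fuse_gathers_to_split(graph, shapes)])
# ['Split', 'Squeeze', 'Squeeze', 'Squeeze']
```

Requiring full coverage of the axis is what makes the rewrite safe in this sketch: if only two of the three slices were gathered, a Split would produce an unused output and the pattern is left alone.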

In a real example, one GatherGrad takes 15 ms and there are three of them
in each layer of the graph; after the fusion, the single Concat that
replaces them takes only 35 µs. The total time of a training step improved
from 1.5 s to 0.4 s.
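As a quick check of the figures quoted above, the per-layer backward cost of this pattern and the overall step speedup can be computed directly. Only the numbers from the message are used; the variable names are illustrative.

```python
# Figures quoted in the commit message.
gathergrad_ms = 15.0        # one GatherGrad kernel
gathergrad_per_layer = 3    # one each for the q, k and v slices
concat_ms = 0.035           # the single Concat that replaces them (35 us)

before_ms = gathergrad_ms * gathergrad_per_layer  # 45.0 ms per layer
after_ms = concat_ms                              # 0.035 ms per layer
per_layer_speedup = before_ms / after_ms          # roughly 1286x

step_before_s, step_after_s = 1.5, 0.4
step_speedup = step_before_s / step_after_s       # 3.75x

print(f"per layer: {before_ms} ms -> {after_ms} ms ({per_layer_speedup:.0f}x)")
print(f"per step:  {step_before_s} s -> {step_after_s} s ({step_speedup:.2f}x)")
```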
2022-09-29 11:09:57 +08:00
| Name | Last commit | Date |
| --- | --- | --- |
| orttraining | Multiple Gather to Split Fusion (#13095) | 2022-09-29 11:09:57 +08:00 |
| pytorch_frontend_examples | Set black's target version (#11370) | 2022-04-27 14:52:19 -07:00 |
| tools | Remove unused orttraining amd dockerfiles and scripts (#12707) | 2022-09-02 18:43:21 -07:00 |