Commit graph

2 commits

Author SHA1 Message Date
Tianlei Wu
fbc3927231
[CUDA] cuDNN Flash Attention (#21629)
### Description
- [x] Add cuDNN flash attention using cudnn frontend, and enable it in
MultiHeadAttention operator.
- [x] Support attention mask.
- [x] Support attention bias.
- [x] Update tests and benchmark script.

The cuDNN SDPA is disabled by default. To enable it, need the following:
(1) Requires cuDNN 9.3 or newer version installed.
(2) Set an environment variable `ORT_ENABLE_CUDNN_FLASH_ATTENTION=1` or
set `sdpa_kernel=8` cuda provider option to enable it.
(3) Only works on devices with compute capability >= 8.0.

Note that some combinations of parameters might be rejected due to
limited support of head dimension or sequence lengths.

Future Works:
(1) FP8 and BF16 APIs.  Currently, only API for FP16 are exposed.
(2) Add API to support ragged batching (padding removed in inputs).
(3) Support other input formats (like QKV_BS3NH).
(4) Currently, q are converted to BSNH, k/v are converted to either BSNH
or BNSH format. May do some experiment to see whether converting q to
BNSH could be better in some case.

### Example Benchmark Results on H100

The following tests are on FP16 MultiHeadAttention operator without
attention mask and attention bias.

#### Test Setting 1
batch_size | sequence_length | past_sequence_length | num_heads |
head_size
-- | -- | -- | -- | --
16 | 256 | 0 | 32 | 128

format | average_latency | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000075 | 229.5 | torch:flash
Q,K,V (BNSH) | 0.000119 | 144.8 | torch:efficient
Q,K,V (BNSH) | 0.000224 | 76.5 | torch:math
Q,K,V (BSNH) | 0.000075 | 227.8 | ort:cudnn
Q,K,V (BSNH) | 0.000094 | 182.8 | ort:flash
Q,K,V (BSNH) | 0.000138 | 124.7 | ort:efficient
Q,K,V (BSNH) | 0.000438 | 39.3 | ort:math
Q,KV | 0.000129 | 133.0 | ort:cudnn
Q,KV | 0.000151 | 114.1 | ort:flash
Q,KV | 0.000194 | 88.5 | ort:efficient
QKV | 0.000154 | 111.8 | ort:cudnn
QKV | 0.000175 | 98.0 | ort:flash
QKV | 0.000217 | 79.0 | ort:efficient

#### Test Setting 2

batch_size | sequence_length | past_sequence_length | num_heads |
head_size
-- | -- | -- | -- | --
16 | 512 | 0 | 16 | 64

format | average_latency | tflops | kernel
-- | -- | -- | --
Q,K,V (BNSH) | 0.000069 | 249.2 | torch:flash
Q,K,V (BNSH) | 0.000141 | 121.7 | torch:efficient
Q,K,V (BNSH) | 0.000294 | 58.5 | torch:math
Q,K,V (BSNH) | 0.000077 | 221.7 | ort:cudnn
Q,K,V (BSNH)  | 0.000087 | 196.6 | ort:flash
Q,K,V (BSNH)  | 0.000163 | 105.6 | ort:efficient
Q,K,V (BSNH)  | 0.000651 | 26.4 | ort:math
Q,KV | 0.000103 | 167.1 | ort:cudnn
Q,KV | 0.000117 | 146.3 | ort:flash
Q,KV | 0.000192 | 89.6 | ort:efficient
QKV | 0.000113 | 151.5 | ort:cudnn
QKV | 0.000128 | 134.7 | ort:flash
QKV | 0.000201 | 85.3 | ort:efficient
2024-08-20 08:50:22 -07:00
Julius Tischbein
1391354265
Adding CUDNN Frontend and use for CUDA NN Convolution (#19470)
### Description
Added CUDNN Frontend and used it for NHWC convolutions, and optionally
fuse activation.

#### Backward compatible 
- For model existed with FusedConv, model can still run. 
- If ORT is built with cuDNN 8, cuDNN frontend will not be built into
binary. Old kernels (using cudnn backend APIs) are used.

#### Major Changes
- For cuDNN 9, we will enable cudnn frontend to fuse convolution and
bias when a provider option `fuse_conv_bias=1`.
- Remove the fusion of FusedConv from graph transformer for CUDA
provider, so there will not be FusedConv be added to graph for CUDA EP
in the future.
- Update cmake files regarding to cudnn settings. The search order of
CUDNN installation in build are like the following:
  * environment variable `CUDNN_PATH`
* `onnxruntime_CUDNN_HOME` cmake extra defines. If a build starts from
build.py/build.sh, user can pass it through `--cudnn_home` parameter, or
by environment variable `CUDNN_HOME` if `--cudnn_home` not used.
* cudnn python package installation directory like
python3.xx/site-packages/nvidia/cudnn
  * CUDA installation path

#### Potential Issues

- If ORT is built with cuDNN 8, FusedConv fusion is no longer done
automatically, so some model might have performance regression. If user
still wants FusedConv operator for performance reason, they can still
have multiple ways to walkaround: like use older version of onnxruntime;
or use older version of ORT to save optimized onnx, then run with latest
version of ORT. We believe that majority users have moved to cudnn 9
when 1.20 release (since the default in ORT and PyTorch is cudnn 9 for 3
months when 1.20 release), so the impact is small.
- cuDNN graph uses TF32 by default, and user cannot disable TF32 through
the use_tf32 cuda provider option. If user encounters accuracy issue
(like in testing), user has to set environment variable
`NVIDIA_TF32_OVERRIDE=0` to disable TF32. Need update the document of
use_tf32 later.

#### Follow ups
This is one of PRs that target to enable NHWC convolution in CUDA EP by
default if device supports it. There are other changes will follow up to
make it possible.
(1) Enable `prefer_nhwc` by default for device with sm >= 70. 
(2) Change `fuse_conv_bias=1` by default after more testing.
(3) Add other NHWC operators (like Resize or UpSample).

### Motivation and Context

The new CUDNN Frontend library provides the functionality to fuse
operations and provides new heuristics for kernel selection. Here it
fuses the convolution with the pointwise bias operation. On the [NVIDIA
ResNet50](https://pytorch.org/hub/nvidia_deeplearningexamples_resnet50/)
we get a performance boost from 49.1144 ms to 42.4643 ms per inference
on a 2560x1440 input (`onnxruntime_perf_test -e cuda -I -q -r 100-d 1 -i
'prefer_nhwc|1' resnet50.onnx`).

---------

Co-authored-by: Tianlei Wu <tlwu@microsoft.com>
Co-authored-by: Maximilian Mueller <maximilianm@nvidia.com>
2024-08-02 15:16:42 -07:00