onnxruntime/cmake
Tianlei Wu 9f0fae29e8
[CUDA] Add SparseAttention operator for Phi-3-small (#20216)
### Description
Add CUDA implementation for block sparse attention for Phi-3-small.

Block sparse attention was proposed in [Sparse
Transformers](https://arxiv.org/pdf/1904.10509) by OpenAI, and was also
adopted in [BigBird](https://arxiv.org/pdf/2007.14062) with a different
sparse layout.

In Phi-3-small, the sparse layout is static, and works with
unidirectional (causal) attention.
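
To make "static causal block sparse layout" concrete, here is a minimal sketch in plain Python. The parameter names `local_blocks` and `vert_stride` are borrowed from the test configuration in the Performance section. The rule that combines them (a local band plus vertical stripes, shifted by a hypothetical `layout_index`) is an assumption for illustration only, not the exact layout used by Phi-3-small or by this operator.

```python
def build_block_mask(num_blocks, local_blocks, vert_stride, layout_index=0):
    """Return a num_blocks x num_blocks 0/1 mask of attended blocks.

    Assumed rule (illustrative only): query block q attends to key block k
    when k <= q (causal) and either k is one of the most recent
    `local_blocks` blocks, or k lies on a vertical stripe chosen by
    `vert_stride` and shifted by `layout_index`.
    """
    mask = []
    for q in range(num_blocks):
        row = []
        for k in range(num_blocks):
            causal = k <= q
            local = (q - k) < local_blocks
            stripe = (k + 1 + layout_index) % vert_stride == 0
            row.append(1 if causal and (local or stripe) else 0)
        mask.append(row)
    return mask


mask = build_block_mask(num_blocks=8, local_blocks=2, vert_stride=4)
for row in mask:
    print(row)
```

Because the mask depends only on these parameters and not on the input, it can be computed once and reused for every call, which is what makes the layout static.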

Compared to dense attention, block sparse attention speeds up both
training and inference. It can also save memory and thus support a longer
context length.

- [x] Add operator spec and shape inference
- [x] Symbolic shape inference
- [x] Refactor GroupQueryAttention to expose common kernels for kv cache
concatenation, q/k/v transpose etc.
- [x] Add CUDA kernel to convert block mask to CSR format
- [x] Add CUDA kernel to generate position ids
- [x] Add compile script and template files to convert Triton kernel to
cubin and dispatcher.
- [x] Add Triton kernel v1 for prompt
- [x] Add Triton kernel v2 for token generation and support padding
- [x] Update IO Binding Helper to allow buffer sharing.
- [x] Test relevance
- [x] Test performance
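
The block mask to CSR conversion listed above can be sketched on the CPU as follows. This is a reference illustration in plain Python, not the CUDA kernel added by this change. The function name `block_mask_to_csr` and the output names `row_indices` and `col_indices` are hypothetical.

```python
def block_mask_to_csr(mask):
    """Convert a dense 0/1 block mask (list of rows) to CSR arrays.

    Returns (row_indices, col_indices). row_indices has len(mask) + 1
    entries, and col_indices[row_indices[r]:row_indices[r + 1]] lists the
    nonzero column indices of row r in ascending order.
    """
    row_indices = [0]
    col_indices = []
    for row in mask:
        col_indices.extend(c for c, v in enumerate(row) if v)
        row_indices.append(len(col_indices))
    return row_indices, col_indices


# A 4x4 causal block mask: the diagonal block plus block column 0.
example_mask = [
    [1, 0, 0, 0],
    [1, 1, 0, 0],
    [1, 0, 1, 0],
    [1, 0, 0, 1],
]
row_indices, col_indices = block_mask_to_csr(example_mask)
print(row_indices)  # [0, 1, 3, 5, 7]
print(col_indices)  # [0, 0, 1, 0, 2, 0, 3]
```

With CSR, a kernel processing query block `r` only needs to walk the key blocks listed between `row_indices[r]` and `row_indices[r + 1]`, rather than scanning a full row of the dense mask.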

### Performance
Tested on A100-SXM4-80GB with `batch_size=4, num_heads=32,
max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16,
vert_stride=8, num_layout=8`

We compare sparse attention with the corresponding GQA using a local
attention window size of 1024, and with GQA using dense causal attention.

Average latency in milliseconds (for fused attention kernel used in
prompt prefilling):

seq_len | GQA-Dense | GQA-Local | SparseAttention
-- | -- | -- | --
64 | 0.0465 | 0.0722 | 0.0641
128 | 0.0618 | 0.0787 | 0.0672
256 | 0.1086 | 0.1076 | 0.0943
512 | 0.2535 | 0.2487 | 0.1676
1024 | 0.7042 | 0.7050 | 0.3800
2048 | 2.4125 | 1.9316 | 0.8966
4096 | 8.9346 | 4.5699 | 2.1129
8192 | 40.5401 | 10.3508 | 5.1748
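
The relative speedups implied by the prompt table can be worked out with a few lines of Python. The latencies are copied from the table above; the helper name `speedups` is only for illustration.

```python
# Latency in ms from the prompt table: (GQA-Dense, GQA-Local, SparseAttention).
prompt_latency_ms = {
    64: (0.0465, 0.0722, 0.0641),
    128: (0.0618, 0.0787, 0.0672),
    256: (0.1086, 0.1076, 0.0943),
    512: (0.2535, 0.2487, 0.1676),
    1024: (0.7042, 0.7050, 0.3800),
    2048: (2.4125, 1.9316, 0.8966),
    4096: (8.9346, 4.5699, 2.1129),
    8192: (40.5401, 10.3508, 5.1748),
}


def speedups(table):
    """Map seq_len to (dense / sparse, local / sparse) latency ratios."""
    return {
        seq_len: (dense / sparse, local / sparse)
        for seq_len, (dense, local, sparse) in table.items()
    }


for seq_len, (vs_dense, vs_local) in speedups(prompt_latency_ms).items():
    print(f"{seq_len:>5}: {vs_dense:.2f}x vs dense, {vs_local:.2f}x vs local")
```

At `seq_len=8192` this gives roughly 7.8x over dense GQA and 2.0x over local GQA, while at `seq_len=64` the ratio against dense GQA is below 1, meaning dense GQA is faster for very short prompts.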

Average latency in milliseconds (for fused attention kernel used in
token generation):

past_seq_len | GQA-Dense | GQA-Local | SparseAttention
-- | -- | -- | --
64 | 0.0186 | 0.0186 | 0.0870
128 | 0.0408 | 0.0466 | 0.1165
256 | 0.0530 | 0.0592 | 0.0988
512 | 0.0445 | 0.0447 | 0.1150
1024 | 0.0634 | 0.0640 | 0.1454
2048 | 0.1027 | 0.0637 | 0.1589
4096 | 0.1789 | 0.0631 | 0.1806
8192 | 0.3288 | 0.0655 | 0.2146

We can see that the kernel for token generation still has room to
improve.

#### Limitations
Only right-side padding and unidirectional attention are supported.

The following are not supported in the first version:
(1) Packed mode like PackedMultiHeadAttention, where padding has been
removed from the input.
(2) paged attention.
(3) bidirectional attention.
(4) GPU compute capability other than 8.0, 8.6, or 8.9.
(5) Left side padding.
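
As a small illustration of limitation (4), a caller could check the device's compute capability before choosing this operator. The sketch below is an assumption about how such a check might look on the caller's side, not code from this change. The names `SUPPORTED_COMPUTE_CAPABILITIES` and `is_sparse_attention_supported` are hypothetical.

```python
# Compute capabilities listed as supported in the first version.
SUPPORTED_COMPUTE_CAPABILITIES = {(8, 0), (8, 6), (8, 9)}


def is_sparse_attention_supported(major, minor):
    """Return True if compute capability major.minor is in the supported set."""
    return (major, minor) in SUPPORTED_COMPUTE_CAPABILITIES


print(is_sparse_attention_supported(8, 0))  # True
print(is_sparse_attention_supported(9, 0))  # False
```

If PyTorch is available, `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple that could be passed to this helper.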

Some of these limitations will be removed in the future (possibly in a new
operator).
2024-04-30 09:06:29 -07:00
external upgrade emsdk to 3.1.57 (#20295) 2024-04-19 23:05:18 -07:00
patches Add patch for ONNX 1.16.0 shape inference bug (#20316) 2024-04-17 10:23:22 -07:00
tensorboard
adjust_global_compile_flags.cmake Support visionos build (#20365) 2024-04-23 18:15:07 -07:00
arm64x.cmake Build onnxruntime.dll as arm64x (#18633) 2023-12-06 16:49:00 -08:00
CMakeLists.txt Update regex to match correct pattern. (#20483) 2024-04-29 10:43:31 -07:00
CMakeSettings.json
codeconv.runsettings
deps.txt Integration with ONNX 1.16.0 (#19745) 2024-04-12 09:46:49 -07:00
deps_update_and_upload.py Update google benchmark to 1.8.3. (#19734) 2024-03-01 11:01:58 -08:00
EnableVisualStudioCodeAnalysis.props
gdk_toolchain.cmake
Info.plist.in
libonnxruntime.pc.cmake.in
linux_arm32_crosscompile_toolchain.cmake Add a build validation for Linux ARM64 cross-compile (#18200) 2023-11-08 13:03:18 -08:00
linux_arm64_crosscompile_toolchain.cmake Add a build validation for Linux ARM64 cross-compile (#18200) 2023-11-08 13:03:18 -08:00
maccatalyst_prepare_objects_for_prelink.py Support xcframework for mac catalyst builds. (#19534) 2024-03-20 10:55:19 -07:00
nuget_helpers.cmake
onnxruntime.cmake Support xcframework for mac catalyst builds. (#19534) 2024-03-20 10:55:19 -07:00
onnxruntime_codegen_tvm.cmake
onnxruntime_common.cmake Fix build errors from date/date.h C++20 compatibility (#20139) 2024-04-02 22:10:25 -07:00
onnxruntime_compile_triton_kernel.cmake [CUDA] Add SparseAttention operator for Phi-3-small (#20216) 2024-04-30 09:06:29 -07:00
onnxruntime_config.h.in Enabling c++ 20 in MacOS build (#16187) 2023-09-26 11:27:02 -07:00
onnxruntime_csharp.cmake
onnxruntime_flatbuffers.cmake Rework some external targets to ease building with -DFETCHCONTENT_FULLY_DISCONNECTED=ON (#15323) 2023-04-03 17:45:12 -07:00
onnxruntime_framework.cmake [C#, CPP] Introduce Float16/BFloat16 support and tests for C#, C++ (#16506) 2023-07-14 10:46:52 -07:00
onnxruntime_framework.natvis [C#, CPP] Introduce Float16/BFloat16 support and tests for C#, C++ (#16506) 2023-07-14 10:46:52 -07:00
onnxruntime_fuzz_test.cmake
onnxruntime_graph.cmake [Apple framework] Fix minimal build with training enabled. (#19858) 2024-03-12 11:33:30 -07:00
onnxruntime_ios.toolchain.cmake Support visionos build (#20365) 2024-04-23 18:15:07 -07:00
onnxruntime_java.cmake Update build option for training in java to enable_training_api (#15638) 2023-04-24 11:53:08 -07:00
onnxruntime_java_unittests.cmake Update build option for training in java to enable_training_api (#15638) 2023-04-24 11:53:08 -07:00
onnxruntime_kernel_explorer.cmake [ROCm] TunableOp: Update rocBLAS get_solutions API (since ROCm5.6) (#16657) 2023-07-13 11:20:26 +08:00
onnxruntime_language_interop_ops.cmake
onnxruntime_mlas.cmake Mlas Gemm 4bit avx2, avx512, and avx512vnni kernels (#20163) 2024-04-25 21:30:50 -07:00
onnxruntime_nodejs.cmake Support building Windows CUDA with Ninja (#20176) 2024-04-03 11:19:31 +08:00
onnxruntime_objectivec.cmake Objective C Training API: TrainingSession (#16374) 2023-06-28 09:13:56 -07:00
onnxruntime_opschema_lib.cmake
onnxruntime_optimizer.cmake [ROCm] Fix hipify error: fast_divmod.h: No such file or directory (#19060) 2024-01-10 14:49:19 +08:00
onnxruntime_providers.cmake Add initial support for CoreML ML Program to the CoreML EP. (#19347) 2024-02-15 08:46:03 +10:00
onnxruntime_providers_acl.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_armnn.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_azure.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_cann.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_coreml.cmake Fix Objective-C static analysis warnings. (#20417) 2024-04-24 11:48:29 -07:00
onnxruntime_providers_cpu.cmake Support visionos build (#20365) 2024-04-23 18:15:07 -07:00
onnxruntime_providers_cuda.cmake Enable CUDA EP unit testing on Windows (#20039) 2024-03-27 13:32:36 -07:00
onnxruntime_providers_dml.cmake Delay load dxcore.dll in addition to ext-ms-win-dxcore-l1-1-0.dll (#18913) 2023-12-26 12:33:42 -08:00
onnxruntime_providers_dnnl.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_js.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_migraphx.cmake CUDA EP vs ROCM EP hipify audit (#17776) 2023-10-13 10:13:53 +08:00
onnxruntime_providers_nnapi.cmake Make partitioning utils QDQ aware so it does not break up QDQ node units (#19723) 2024-03-12 10:55:49 +10:00
onnxruntime_providers_openvino.cmake Ort openvino npu 1.17 master (#19966) 2024-03-21 18:44:00 -07:00
onnxruntime_providers_qnn.cmake Make partitioning utils QDQ aware so it does not break up QDQ node units (#19723) 2024-03-12 10:55:49 +10:00
onnxruntime_providers_rknpu.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_rocm.cmake CUDA EP vs ROCM EP hipify audit (#17776) 2023-10-13 10:13:53 +08:00
onnxruntime_providers_tensorrt.cmake [TensorRT] adapt for TRT lib name change after TRT 10 GA (#20445) 2024-04-24 21:46:54 -07:00
onnxruntime_providers_tvm.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_vitisai.cmake [VitisAI] Solve the problem that gsl cannot be found when compiling under linux (#20466) 2024-04-28 20:56:16 -07:00
onnxruntime_providers_webnn.cmake Split onnxruntime_providers.cmake to multiple (#17853) 2023-10-09 20:33:44 -07:00
onnxruntime_providers_xnnpack.cmake Make partitioning utils QDQ aware so it does not break up QDQ node units (#19723) 2024-03-12 10:55:49 +10:00
onnxruntime_pyop.cmake
onnxruntime_python.cmake [qnn ep] include qnn sdk in onnxruntime-qnn python whl (#20485) 2024-04-29 09:44:54 -07:00
onnxruntime_rocm_hipify.cmake [CUDA] Add SparseAttention operator for Phi-3-small (#20216) 2024-04-30 09:06:29 -07:00
onnxruntime_session.cmake added support for cmake "find_package" (#8919) 2023-06-19 22:20:31 -07:00
onnxruntime_snpe_provider.cmake
onnxruntime_training.cmake Triton Codegen for ORTModule (#15831) 2023-07-13 18:17:58 +08:00
onnxruntime_unittests.cmake fix the build issue for Win Arm64 Release build (#20475) 2024-04-25 22:08:19 -07:00
onnxruntime_util.cmake
onnxruntime_visionos.toolchain.cmake Support visionos build (#20365) 2024-04-23 18:15:07 -07:00
onnxruntime_webassembly.cmake upgrade emsdk to 3.1.57 (#20295) 2024-04-19 23:05:18 -07:00
precompiled_header.cmake
riscv64.toolchain.cmake Enable RISC-V 64-bit Cross-Compiling Support for ONNX Runtime on Linux (#19238) 2024-01-24 16:27:05 -08:00
Sdl.ruleset Add a Github workflow for Prefast (#15763) 2023-05-03 11:42:51 -07:00
set_winapi_family_desktop.h
target_delayload.cmake
uwp_stubs.h Run clang-format in CI (#15524) 2023-04-18 09:26:58 -07:00
wcos_rules_override.cmake Stop using apiset in OneCore build: use onecoreuap.lib instead of onecoreuap_apiset.lib (#19632) 2024-02-23 22:31:57 -08:00
winml.cmake [CP] Fix for xfgcheck and Fix WAI ARM64 build (#19634) (#19644) 2024-03-13 17:54:06 -07:00
winml_cppwinrt.cmake
winml_sdk_helpers.cmake
winml_unittests.cmake Update C/C++ dependencies: abseil, date, nsync, googletest, wil, mp11, cpuinfo and safeint (#15470) 2023-09-08 13:35:04 -07:00