mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Upgrade cutlass to 3.5.1 and cudnn frontend to 1.7.0 (#22316)
### Description Upgrade cutlass to 3.5.1 Upgrade cudnn_frontend to 1.7.0
This commit is contained in:
parent
f25f3868a7
commit
f3f33bfa05
7 changed files with 1345 additions and 108 deletions
|
|
@ -296,7 +296,7 @@
|
|||
"component": {
|
||||
"type": "git",
|
||||
"git": {
|
||||
"commitHash": "7d49e6c7e2f8896c47f586706e67e1fb215529dc",
|
||||
"commitHash": "f7b19de32c5d1f3cedfc735c2849f12b537522ee",
|
||||
"repositoryUrl": "https://github.com/NVIDIA/cutlass.git"
|
||||
},
|
||||
"comments": "cutlass"
|
||||
|
|
@ -346,7 +346,7 @@
|
|||
"component": {
|
||||
"type": "git",
|
||||
"git": {
|
||||
"commitHash": "98ca4e1941fe3263f128f74f10063a3ea35c7019",
|
||||
"commitHash": "de355c7094af70467f2b264f531ab5c5f4401c42",
|
||||
"repositoryUrl": "https://github.com/NVIDIA/cudnn-frontend.git"
|
||||
},
|
||||
"comments": "cudnn_frontend"
|
||||
|
|
|
|||
|
|
@ -53,11 +53,11 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/ca678952a9a8eaa6de112
|
|||
re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88
|
||||
safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac
|
||||
tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
|
||||
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.0.zip;ae038931b9fc2c416c17d9cda91d9706b343f56d
|
||||
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.1.zip;e49b2b964163d27765a5002d210a2f3c73771835
|
||||
utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
|
||||
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
|
||||
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557
|
||||
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
|
||||
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029
|
||||
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.7.0.zip;d0753d8d5b39947ca0729d7773cb84653a129eb1
|
||||
dawn;https://github.com/google/dawn/archive/511eb80847afe6bded34ec491a38d5d78ba2d604.zip;c493f5aca5586f6634e25d0121c85df71189fb99
|
||||
kleidiai;https://gitlab.arm.com/kleidi/kleidiai/-/archive/v0.2.0/kleidiai-v0.2.0.zip;B1E3173992FD91F20DB904AB77D6E901778C2681
|
||||
|
|
|
|||
1
cmake/external/cutlass.cmake
vendored
1
cmake/external/cutlass.cmake
vendored
|
|
@ -3,7 +3,6 @@ FetchContent_Declare(
|
|||
cutlass
|
||||
URL ${DEP_URL_cutlass}
|
||||
URL_HASH SHA1=${DEP_SHA1_cutlass}
|
||||
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass_3.5.0.patch
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(cutlass)
|
||||
|
|
|
|||
|
|
@ -1,100 +0,0 @@
|
|||
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
|
||||
index 4c80f549..5ad610c8 100644
|
||||
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
|
||||
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
|
||||
@@ -189,6 +189,7 @@ struct AttentionKernel {
|
||||
|
||||
// Scale
|
||||
accum_t scale = 0.0;
|
||||
+ accum_t softcap = 0.0;
|
||||
|
||||
// Dimensions/strides
|
||||
int32_t head_dim = 0;
|
||||
@@ -221,6 +222,8 @@ struct AttentionKernel {
|
||||
int32_t num_batches = 0;
|
||||
int32_t num_heads = 0;
|
||||
|
||||
+ bool use_smooth_softmax = false;
|
||||
+
|
||||
// dropout
|
||||
bool use_dropout = false;
|
||||
unsigned long long dropout_batch_head_rng_offset = 0;
|
||||
@@ -818,6 +821,15 @@ struct AttentionKernel {
|
||||
accum =
|
||||
cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
|
||||
}
|
||||
+
|
||||
+ // apply softcap if applicable
|
||||
+ if (p.softcap > 0.0) {
|
||||
+ accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(1.0 / p.softcap, accum);
|
||||
+ for (int i = 0; i < accum.size(); ++i) {
|
||||
+ accum[i] = cutlass::fast_tanh(accum[i]);
|
||||
+ }
|
||||
+ accum = cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.softcap, accum);
|
||||
+ }
|
||||
|
||||
// apply attention bias if applicable
|
||||
if (kSupportsBias && p.attn_bias_ptr != nullptr) {
|
||||
@@ -897,7 +909,8 @@ struct AttentionKernel {
|
||||
p.num_keys - iter_key_start,
|
||||
iter_key_start == 0,
|
||||
iteratorC_tile_offset,
|
||||
- kSupportsBias ? 1.0f : p.scale);
|
||||
+ kSupportsBias ? 1.0f : p.scale,
|
||||
+ p.use_smooth_softmax);
|
||||
|
||||
// Output results to shared-memory
|
||||
int warp_idx_mn_0 = my_warp_id %
|
||||
@@ -1166,7 +1179,8 @@ struct AttentionKernel {
|
||||
int max_col,
|
||||
bool is_first,
|
||||
typename WarpIteratorC::TensorCoord const& tile_offset,
|
||||
- float scaling) {
|
||||
+ float scaling,
|
||||
+ bool use_smooth_softmax) {
|
||||
/* Iterates on the accumulator and corresponding position on result matrix
|
||||
|
||||
(1) Update `mi[r]` to the max value of the row `r`
|
||||
@@ -1257,7 +1271,7 @@ struct AttentionKernel {
|
||||
accum_t mi_row, total_row;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
- [&](int accum_m) { mi_row = mi[accum_m]; },
|
||||
+ [&](int accum_m) { mi_row = mi[accum_m];},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
frag[idx] =
|
||||
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
|
||||
@@ -1294,7 +1308,7 @@ struct AttentionKernel {
|
||||
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
|
||||
total_row += addition_storage[id + kQueriesPerBlock * i];
|
||||
}
|
||||
- s_prime[id] = total_row;
|
||||
+ s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
|
||||
}
|
||||
}
|
||||
|
||||
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
|
||||
index 964d2ff3..676ba768 100644
|
||||
--- a/include/cutlass/functional.h
|
||||
+++ b/include/cutlass/functional.h
|
||||
@@ -39,6 +39,7 @@
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
+#include <cuda_fp16.h>
|
||||
|
||||
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
||||
#include <mma.h>
|
||||
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
half_t operator()(half_t const &lhs) const {
|
||||
#if defined(__CUDA_ARCH__)
|
||||
+#if (__CUDA_ARCH__ >= 530)
|
||||
auto result = hrsqrt(reinterpret_cast<__half const &>(lhs));
|
||||
return reinterpret_cast<half_t const &>(result);
|
||||
+#else
|
||||
+ return half_t::convert((rsqrtf(half_t::convert(lhs))));
|
||||
+#endif
|
||||
#else
|
||||
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
|
||||
#endif
|
||||
|
|
@ -10,7 +10,7 @@
|
|||
#endif
|
||||
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
#include "41_fused_multi_head_attention/kernel_forward.h"
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
|
|||
1338
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h
Normal file
1338
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -11,7 +11,7 @@ steps:
|
|||
packageType: upack
|
||||
feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
|
||||
definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
|
||||
version: 1.0.188
|
||||
version: 1.0.191
|
||||
downloadPath: $(Build.BinariesDirectory)/deps
|
||||
|
||||
# The private ADO project
|
||||
|
|
@ -22,7 +22,7 @@ steps:
|
|||
packageType: upack
|
||||
feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
|
||||
definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
|
||||
version: 1.0.188
|
||||
version: 1.0.191
|
||||
downloadPath: $(Build.BinariesDirectory)/deps
|
||||
|
||||
# You can add more ADO accounts at here.
|
||||
|
|
|
|||
Loading…
Reference in a new issue