onnxruntime/cmake/patches/cutlass/cutlass_3.5.0.patch
Tianlei Wu b3fc9b5a0e
[CUDA] upgrade cutlass to 3.5.0 (#20940)
### Description
Upgrade cutlass to 3.5 to fix build errors using CUDA 12.4 or 12.5 in
Windows
- [x] Upgrade cutlass to 3.5.0.
- [x] Fix flash attention build error with latest cutlass header files
and APIs. This fix is provided by @wangyems.
- [x] Update efficient attention to use new cutlass fmha interface.
- [x] Patch cutlass to fix `hrsqrt` not found error for sm < 53.
- [x] Disable TF32 Staged Accumulation to fix blkq4_fp16_gemm_sm80_test
build error for cuda 11.8 to 12.3.
- [x] Disable TRT 10 deprecate warnings. 

The following are not included in this PR:
* TRT provider replaces the deprecated APIs.
* Fix blkq4_fp16_gemm_sm80_test build error for cuda 12.4 or 12.5. This
test is not built by default unless you add `--cmake_extra_defines
onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON` in build command.

To integrate to rel-1.18.1: Either bring in other changes (like onnx
1.16.1), or generate manifest and upload a new ONNX Runtime Build Time
Deps artifact based on rel-1.18.1.

### Motivation and Context
https://github.com/microsoft/onnxruntime/issues/19891
https://github.com/microsoft/onnxruntime/issues/20924
https://github.com/microsoft/onnxruntime/issues/20953
2024-06-11 13:32:15 -07:00

25 lines
801 B
Diff

diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 964d2ff3..b366bc14 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
#include "cutlass/numeric_types.h"
#include <cuda_runtime.h>
+#include <cuda_fp16.h>
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include <mma.h>
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
CUTLASS_HOST_DEVICE
half_t operator()(half_t const &lhs) const {
#if defined(__CUDA_ARCH__)
+#if (__CUDA_ARCH__ >= 530)
auto result = hrsqrt(reinterpret_cast<__half const &>(lhs));
return reinterpret_cast<half_t const &>(result);
+#else
+ return half_t::convert((rsqrtf(half_t::convert(lhs))));
+#endif
#else
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
#endif