mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-07 00:13:17 +00:00
### Description

Upgrade cutlass to 3.5 to fix build errors when using CUDA 12.4 or 12.5 on Windows.

- [x] Upgrade cutlass to 3.5.0.
- [x] Fix the flash attention build error with the latest cutlass header files and APIs. This fix is provided by @wangyems.
- [x] Update efficient attention to use the new cutlass fmha interface.
- [x] Patch cutlass to fix the `hrsqrt` not found error for sm < 53.
- [x] Disable TF32 Staged Accumulation to fix the blkq4_fp16_gemm_sm80_test build error for CUDA 11.8 to 12.3.
- [x] Disable TRT 10 deprecation warnings.

The following are not included in this PR:

* Replacing the deprecated APIs in the TRT provider.
* Fixing the blkq4_fp16_gemm_sm80_test build error for CUDA 12.4 or 12.5. This test is not built by default unless you add `--cmake_extra_defines onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON` to the build command.

To integrate into rel-1.18.1: either bring in other changes (like onnx 1.16.1), or generate a manifest and upload a new ONNX Runtime Build Time Deps artifact based on rel-1.18.1.

### Motivation and Context

https://github.com/microsoft/onnxruntime/issues/19891
https://github.com/microsoft/onnxruntime/issues/20924
https://github.com/microsoft/onnxruntime/issues/20953
25 lines
801 B
Diff
25 lines
801 B
Diff
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
|
|
index 964d2ff3..b366bc14 100644
|
|
--- a/include/cutlass/functional.h
|
|
+++ b/include/cutlass/functional.h
|
|
@@ -39,6 +39,7 @@
|
|
#include "cutlass/numeric_types.h"
|
|
|
|
#include <cuda_runtime.h>
|
|
+#include <cuda_fp16.h>
|
|
|
|
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
|
#include <mma.h>
|
|
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
|
|
CUTLASS_HOST_DEVICE
|
|
half_t operator()(half_t const &lhs) const {
|
|
#if defined(__CUDA_ARCH__)
|
|
+#if (__CUDA_ARCH__ >= 530)
|
|
auto result = hrsqrt(reinterpret_cast<__half const &>(lhs));
|
|
return reinterpret_cast<half_t const &>(result);
|
|
+#else
|
|
+ return half_t::convert((rsqrtf(half_t::convert(lhs))));
|
|
+#endif
|
|
#else
|
|
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
|
|
#endif
|