diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3800fe238cd..c9e1cebdacc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -725,7 +725,10 @@ set(BUILD_ONEDNN_GRAPH OFF)
 include(cmake/Dependencies.cmake)
 
 # Moved this cmake set option down here because CMAKE_CUDA_COMPILER_VERSION is not avaialble until now
-option(USE_FLASH_ATTENTION "Whether to build the flash_attention kernel for scaled dot product attention" OFF)
+cmake_dependent_option(
+    USE_FLASH_ATTENTION
+    "Whether to build the flash_attention kernel for scaled dot product attention" ON
+    "USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
 if(USE_FLASH_ATTENTION)
   ADD_DEFINITIONS(-DUSE_FLASH_ATTENTION)
 ENDIF()
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/epilogue.h b/aten/src/ATen/native/transformers/cuda/flash_attn/epilogue.h
index 65c3180a9c8..2bf4e1eb548 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/epilogue.h
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/epilogue.h
@@ -63,7 +63,7 @@ struct FMHAEpilogue {
         Element, ElementC, /*ElementsPerAccess=*/4, ThreadblockShape, typename WarpMma::Shape,
         typename WarpMma::Policy::Operator::Shape, typename OutputTileThreadMap::CompactedThreadMap>;
     using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
-    static_assert(WarpTileIterator::kIterations == kIterationsStore);
+    static_assert(WarpTileIterator::kIterations == kIterationsStore, "");
 
     using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
     using OutputFragment = typename SharedLoadIterator::Fragment;
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha.h
index d259280fac5..2bd17da72f7 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha.h
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha.h
@@ -151,4 +151,4 @@ struct Launch_params{
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
+TORCH_API void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp
index 691465c5354..a8d6110e951 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp
@@ -26,6 +26,7 @@
  *
  ******************************************************************************/
 
+#ifdef USE_FLASH_ATTENTION
 #include
 #include
 #include
@@ -241,3 +242,4 @@ mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q
     return result;
 }
 } // namespace fmha
+#endif
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h
index 3dca7e2ac89..226d4ddd2b5 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h
@@ -6,6 +6,7 @@
 
 namespace fmha {
 
+TORCH_API
 std::vector<at::Tensor>
 mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
         const at::Tensor &k,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h
index c4fe1880246..1a41438c662 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h
@@ -102,7 +102,7 @@ struct Gemm_Q_K : public Gemm_Q_K_base {
 
     static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
     // If V is stored in shared memory, we can't load K using the same shared memory.
-    static_assert(Kernel_traits::V_IN_REGS);
+    static_assert(Kernel_traits::V_IN_REGS, "");
 
     static constexpr size_t SMEM_OFFSET_O = Kernel_traits::BYTES_PER_SMEM_Q;
     static constexpr size_t SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + sizeof(typename Smem_O::SharedStorage);
@@ -161,7 +161,7 @@ struct Gemm_Q_K : public Gemm_Q_K_base {
 
     static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
     static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS;
-    static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V);
+    static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V, "");
 
     static constexpr size_t SMEM_OFFSET_V = Kernel_traits::BYTES_PER_SMEM_Q + (SHARE_SMEM_FOR_K_AND_V ? 0 : Kernel_traits::BYTES_PER_SMEM_K);
     static constexpr size_t SMEM_OFFSET_O = SMEM_OFFSET_V + Kernel_traits::BYTES_PER_SMEM_V;
@@ -298,7 +298,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
 
     // Wind gmem tiles to the correct position.
-    static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
+    static_assert(Cta_tile_p::N % Cta_tile_p::M == 0, "");
     const int begin_og = begin;
     begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;
     const int steps_og = steps;
@@ -428,7 +428,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     const int warp_idx = threadIdx.x / 32;
     iter_V.add_tile_offset({kIterationsPV * warp_idx, 0});
     typename WarpIteratorV::Fragment frag_v[kIterationsPV];
-    static_assert(WarpIteratorV::Fragment::kStorageElements == 4 * Mma_tile_o::MMAS_N || WarpIteratorV::Fragment::kStorageElements == 2 * Mma_tile_o::MMAS_N );
+    static_assert(WarpIteratorV::Fragment::kStorageElements == 4 * Mma_tile_o::MMAS_N || WarpIteratorV::Fragment::kStorageElements == 2 * Mma_tile_o::MMAS_N, "");
     #pragma unroll
     for( int ki = 0; ki < kIterationsPV; ++ki ) {
         iter_V.load(frag_v[ki]);
@@ -463,8 +463,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         gemm_q_k(mma_qk, acc_p);
 
         typename Smem_O::OutputFragment out[Smem_O::kIterationsStore];
-        static_assert(GmemIteratorOAccum::kIterations == Smem_O::kIterationsStore);
-        static_assert(GmemIteratorO::kIterations == Smem_O::kIterationsStore);
+        static_assert(GmemIteratorOAccum::kIterations == Smem_O::kIterationsStore, "");
+        static_assert(GmemIteratorO::kIterations == Smem_O::kIterationsStore, "");
         if (!Is_first) {
             #pragma unroll
             for (int iter = 0; iter < GmemIteratorOAccum::kIterations; ++iter) {
@@ -536,8 +536,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
             softmax.template apply_dropout_16bits(ph0, ph1, params.p_dropout_in_uint16_t);
         }
 
-        static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
-        static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
+        static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M, "");
+        static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N, "");
         softmax.pack_noconvert(acc_p);
         cutlass::NumericArrayConverter convert_p;
         auto frag_p = convert_p(acc_p);
@@ -558,13 +558,13 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         // Declare the accumulators for the 2nd gemm.
         WarpMmaPV mma_pv;
         typename WarpMmaPV::FragmentC acc_o;
-        static_assert(WarpMmaPV::FragmentC::kElements == Mma_tile_o::MMAS_M * Mma_tile_o::MMAS_N * 8);
+        static_assert(WarpMmaPV::FragmentC::kElements == Mma_tile_o::MMAS_M * Mma_tile_o::MMAS_N * 8, "");
         acc_o.clear();
 
         // For some reason, WarpMmaPV::FragmentA has length K * N * (8|4) instead of just N * (8|4).
         // We have to first cast frag_p to be array of k x (N * (8|4)), then cast each row to be
         // an array of WarpMmaPV::FragmentA (which is what mma_pv expects).
-        static_assert(decltype(frag_p)::kElements == kIterationsPV * Mma_tile_o::MMAS_M * WarpMmaPV::FragmentA::kElements);
+        static_assert(decltype(frag_p)::kElements == kIterationsPV * Mma_tile_o::MMAS_M * WarpMmaPV::FragmentA::kElements, "");
         const auto frag_p_reshaped = reinterpret_cast (&)[kIterationsPV]>(frag_p);
         #pragma unroll
         for( int ki = 0; ki < kIterationsPV; ++ki ) {
@@ -589,7 +589,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         }
 
         softmax.reduce_max_after_sync_(p_max_o, rows);
-        static_assert(Mma_tile_o::MMAS_M == 1);
+        static_assert(Mma_tile_o::MMAS_M == 1, "");
         for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
             p_max_o[jj][0] *= params.scale_bmm1;
         }
@@ -601,7 +601,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
 
         // Make sure the data is in shared memory.
         __syncthreads();
-        static_assert(Mma_tile_o::MMAS_M == 1);
+        static_assert(Mma_tile_o::MMAS_M == 1, "");
         float p_sum_o[kOutputRowsPerThread][Mma_tile_o::MMAS_M];
         softmax.reduce_sum_after_sync_(p_sum_o, rows);
         if (!Is_first) {
@@ -625,7 +625,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
 
         // Load from shared memory.
         using ArrayTypeO = cutlass::Array;
-        static_assert(OutputTileThreadMap::kElementsPerAccess * kOutputRowsPerThread == Smem_O::kIterationsStore * Smem_O::OutputFragment::kElements);
+        static_assert(OutputTileThreadMap::kElementsPerAccess * kOutputRowsPerThread == Smem_O::kIterationsStore * Smem_O::OutputFragment::kElements, "");
         cutlass::multiplies multiply_fragments;
         if (!Is_first) {
             auto out_reshaped = reinterpret_cast(out);
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_dispatch.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_dispatch.cu
index 344aa07dd3b..7748a779a82 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_dispatch.cu
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_dispatch.cu
@@ -88,7 +88,7 @@ void run_fmha_loop_(Launch_params<FMHA_fprop_params> &launch_params,
     });
 }
 
-void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params,
+TORCH_API void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params,
                     const bool configure) {
     BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
         using elem_type = std::conditional::type;
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/gmem_tile.h b/aten/src/ATen/native/transformers/cuda/flash_attn/gmem_tile.h
index 0102c0611be..ea54086ac36 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/gmem_tile.h
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/gmem_tile.h
@@ -118,7 +118,7 @@ struct Gmem_tile_mma_s : public Base {
     // Store to global memory.
     template
     inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
-        static_assert(Fragment::kStorageElements == 4);
+        static_assert(Fragment::kStorageElements == 4, "");
         #pragma unroll
         for( int mi = 0; mi < M; mi++ ) {
             #pragma unroll
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h
index cfd8b885778..9c630fbd4fe 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h
@@ -85,7 +85,7 @@ struct FMHA_kernel_traits {
 #endif
     using ElementAccum = float;
 
-    static_assert(WARPS_M == 1);
+    static_assert(WARPS_M == 1, "");
     using ThreadblockShapeQK = cutlass::gemm::GemmShape;
     using WarpCountQK = cutlass::gemm::GemmShape;
     using WarpShapeQK = cutlass::gemm::GemmShape<
@@ -144,7 +144,7 @@ struct FMHA_kernel_traits {
     static constexpr size_t BYTES_PER_SMEM_Q = ThreadblockShapeQK::kM * ThreadblockShapeQK::kK * sizeof(Element);
     static constexpr size_t BYTES_PER_SMEM_K = ThreadblockShapeQK::kN * ThreadblockShapeQK::kK * sizeof(Element);
     static constexpr size_t BYTES_PER_SMEM_V = ThreadblockShapePV::kN * ThreadblockShapePV::kK * sizeof(Element);
-    static_assert(BYTES_PER_SMEM_K == BYTES_PER_SMEM_V);
+    static_assert(BYTES_PER_SMEM_K == BYTES_PER_SMEM_V, "");
     static constexpr size_t BYTES_PER_SMEM_QK = BYTES_PER_SMEM_Q + BYTES_PER_SMEM_K;
     // The extra amount of shared memory needed to load V.
     static constexpr size_t BYTES_PER_SMEM_V_EXTRA = SHARE_SMEM_FOR_K_AND_V ? 0u : BYTES_PER_SMEM_V;
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h b/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h
index 6af873aa336..2e121d0e931 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h
@@ -78,7 +78,7 @@ struct Smem_tile_reduce {
 
     static constexpr int ROWS = WARPS_M * MMAS_M * 16;
     static constexpr int COLS = WARPS_N;
-    static_assert(COLS == 4 || COLS == 8);
+    static_assert(COLS == 4 || COLS == 8, "");
     static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
     static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
     static constexpr int ELTS_PER_TILE = ROWS * COLS;
@@ -263,7 +263,7 @@ struct Softmax_base {
         };
         #pragma unroll
         for( int mi = 0; mi < MMAS_M; mi++ ) {
-            static_assert(MMAS_N % 2 == 0);
+            static_assert(MMAS_N % 2 == 0, "");
             #pragma unroll
             for( int ni = 0; ni < MMAS_N; ni += 2 ) {
                 uint4 random_uint4 = ph0();
@@ -319,7 +319,7 @@ struct Softmax : public Softmax_base {
     static constexpr int MMAS_N = Base::MMAS_N;
 
     using Smem_tile_red = Smem_tile_reduce;
-    static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
+    static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N, "");
     // Ctor.
     template<typename Params>
     inline __device__ Softmax(const Params &params, void *smem, int tidx)
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/summary_stats.h b/aten/src/ATen/native/transformers/cuda/flash_attn/summary_stats.h
index 812aaea7977..a3abda34b4e 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/summary_stats.h
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/summary_stats.h
@@ -12,7 +12,7 @@ template
 struct Smem_tile_softmax_lse {
 
     static constexpr int kMmaM = (kRows / kWarpCountM) / kRowsPerMma;
-    static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows);
+    static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows, "");
 
     // static_assert(kWarpCountM == 1);
     // Otherwise we might need to check warp_idx / kWarpCountM == 0 instead of just warp_idx == 0
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h b/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h
index e70f634c26d..7caa29f2086 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h
@@ -332,7 +332,7 @@ __device__ inline T operator()(T const & x, T const & y) { return x + y; }
 
 template<int THREADS>
 struct Allreduce {
-    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4, "");
     template<typename T, typename Operator>
     static __device__ inline T run(T x, Operator &op) {
         constexpr int OFFSET = THREADS / 2;
diff --git a/aten/src/ATen/native/transformers/transformer.cpp b/aten/src/ATen/native/transformers/transformer.cpp
index b7ef1a36a7d..6e376ae7241 100644
--- a/aten/src/ATen/native/transformers/transformer.cpp
+++ b/aten/src/ATen/native/transformers/transformer.cpp
@@ -107,13 +107,14 @@ Tensor transformer_encoder_layer_forward(
   if (norm_first) {
     x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor);
   }
-#if USE_FLASH_ATTENTION
+
+#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION
   if (x.is_nested() &&
       x.is_cuda() &&
       x.dtype() == at::kHalf &&
       !mask.has_value() &&
      (embed_dim / num_heads == 16 || embed_dim / num_heads == 32 || embed_dim / num_heads == 64 || embed_dim / num_heads == 128)) {
-    TORCH_WARN_ONCE("USING FLASH ATTENTION WITH NT");
+    TORCH_WARN_ONCE("transformer_encoder_layer_forward is using flash attention.");
     x = at::linear(x, qkv_weight, qkv_bias);
     x = x.view({x.size(0), -1, 3, num_heads, embed_dim / num_heads});
     x = flash_attention_helper(x, x, x, 0.0, false);
@@ -135,7 +136,7 @@ Tensor transformer_encoder_layer_forward(
           false /* need_weights */,
           true /* average_attn_weights */,
           mask_type));
-#if USE_FLASH_ATTENTION
+#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION
   }
 #endif
   add_in_place(x, src, use_nested_tensor);
diff --git a/build_variables.bzl b/build_variables.bzl
index 9043349deae..4e2a2498413 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -1461,6 +1461,7 @@ aten_cuda_cu_source_list = [
     "aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp",
     "aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.cpp",
     "aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp",
+    "aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp",
 ]
 
 # Files using thrust::sort_by_key need to be linked last