mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Get flash_attn to compile for CUDA 11.6 linux nightly build (#84941)
This PR only attempts to get this code to compile for all archs so that we can dispatch to it in https://github.com/pytorch/pytorch/pull/84653
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84941
Approved by: https://github.com/drisspg, https://github.com/malfet
This commit is contained in:
parent
15435325eb
commit
6a04df3ac8
14 changed files with 36 additions and 28 deletions
|
|
@ -725,7 +725,10 @@ set(BUILD_ONEDNN_GRAPH OFF)
|
|||
include(cmake/Dependencies.cmake)
|
||||
|
||||
# Moved this cmake option down here because CMAKE_CUDA_COMPILER_VERSION is not available until now
|
||||
option(USE_FLASH_ATTENTION "Whether to build the flash_attention kernel for scaled dot product attention" OFF)
|
||||
cmake_dependent_option(
|
||||
USE_FLASH_ATTENTION
|
||||
"Whether to build the flash_attention kernel for scaled dot product attention" ON
|
||||
"USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
|
||||
if(USE_FLASH_ATTENTION)
|
||||
ADD_DEFINITIONS(-DUSE_FLASH_ATTENTION)
|
||||
ENDIF()
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ struct FMHAEpilogue {
|
|||
Element, ElementC, /*ElementsPerAccess=*/4, ThreadblockShape, typename WarpMma::Shape,
|
||||
typename WarpMma::Policy::Operator::Shape, typename OutputTileThreadMap::CompactedThreadMap>;
|
||||
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
|
||||
static_assert(WarpTileIterator::kIterations == kIterationsStore);
|
||||
static_assert(WarpTileIterator::kIterations == kIterationsStore, "");
|
||||
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
|
||||
using OutputFragment = typename SharedLoadIterator::Fragment;
|
||||
|
||||
|
|
|
|||
|
|
@ -151,4 +151,4 @@ struct Launch_params{
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
|
||||
TORCH_API void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@
|
|||
*
|
||||
******************************************************************************/
|
||||
|
||||
#ifdef USE_FLASH_ATTENTION
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
|
|
@ -241,3 +242,4 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
|||
return result;
|
||||
}
|
||||
} // namespace fmha
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
namespace fmha {
|
||||
|
||||
TORCH_API
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
|
|||
|
||||
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
|
||||
// If V is stored in shared memory, we can't load K using the same shared memory.
|
||||
static_assert(Kernel_traits::V_IN_REGS);
|
||||
static_assert(Kernel_traits::V_IN_REGS, "");
|
||||
|
||||
static constexpr size_t SMEM_OFFSET_O = Kernel_traits::BYTES_PER_SMEM_Q;
|
||||
static constexpr size_t SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + sizeof(typename Smem_O::SharedStorage);
|
||||
|
|
@ -161,7 +161,7 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
|
|||
|
||||
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
|
||||
static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS;
|
||||
static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V);
|
||||
static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V, "");
|
||||
|
||||
static constexpr size_t SMEM_OFFSET_V = Kernel_traits::BYTES_PER_SMEM_Q + (SHARE_SMEM_FOR_K_AND_V ? 0 : Kernel_traits::BYTES_PER_SMEM_K);
|
||||
static constexpr size_t SMEM_OFFSET_O = SMEM_OFFSET_V + Kernel_traits::BYTES_PER_SMEM_V;
|
||||
|
|
@ -298,7 +298,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
|
|||
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
|
||||
|
||||
// Wind gmem tiles to the correct position.
|
||||
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
|
||||
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0, "");
|
||||
const int begin_og = begin;
|
||||
begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;
|
||||
const int steps_og = steps;
|
||||
|
|
@ -428,7 +428,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
|
|||
const int warp_idx = threadIdx.x / 32;
|
||||
iter_V.add_tile_offset({kIterationsPV * warp_idx, 0});
|
||||
typename WarpIteratorV::Fragment frag_v[kIterationsPV];
|
||||
static_assert(WarpIteratorV::Fragment::kStorageElements == 4 * Mma_tile_o::MMAS_N || WarpIteratorV::Fragment::kStorageElements == 2 * Mma_tile_o::MMAS_N );
|
||||
static_assert(WarpIteratorV::Fragment::kStorageElements == 4 * Mma_tile_o::MMAS_N || WarpIteratorV::Fragment::kStorageElements == 2 * Mma_tile_o::MMAS_N, "");
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < kIterationsPV; ++ki ) {
|
||||
iter_V.load(frag_v[ki]);
|
||||
|
|
@ -463,8 +463,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
|
|||
gemm_q_k(mma_qk, acc_p);
|
||||
|
||||
typename Smem_O::OutputFragment out[Smem_O::kIterationsStore];
|
||||
static_assert(GmemIteratorOAccum::kIterations == Smem_O::kIterationsStore);
|
||||
static_assert(GmemIteratorO::kIterations == Smem_O::kIterationsStore);
|
||||
static_assert(GmemIteratorOAccum::kIterations == Smem_O::kIterationsStore, "");
|
||||
static_assert(GmemIteratorO::kIterations == Smem_O::kIterationsStore, "");
|
||||
if (!Is_first) {
|
||||
#pragma unroll
|
||||
for (int iter = 0; iter < GmemIteratorOAccum::kIterations; ++iter) {
|
||||
|
|
@ -536,8 +536,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
|
|||
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
|
||||
}
|
||||
|
||||
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
|
||||
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
|
||||
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M, "");
|
||||
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N, "");
|
||||
softmax.pack_noconvert(acc_p);
|
||||
cutlass::NumericArrayConverter<Element, ElementAccum, decltype(acc_p)::kElements, cutlass::FloatRoundStyle::round_to_nearest> convert_p;
|
||||
auto frag_p = convert_p(acc_p);
|
||||
|
|
@ -558,13 +558,13 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
|
|||
// Declare the accumulators for the 2nd gemm.
|
||||
WarpMmaPV mma_pv;
|
||||
typename WarpMmaPV::FragmentC acc_o;
|
||||
static_assert(WarpMmaPV::FragmentC::kElements == Mma_tile_o::MMAS_M * Mma_tile_o::MMAS_N * 8);
|
||||
static_assert(WarpMmaPV::FragmentC::kElements == Mma_tile_o::MMAS_M * Mma_tile_o::MMAS_N * 8, "");
|
||||
acc_o.clear();
|
||||
|
||||
// For some reason, WarpMmaPV::FragmentA has length K * N * (8|4) instead of just N * (8|4).
|
||||
// We have to first cast frag_p to be array of k x (N * (8|4)), then cast each row to be
|
||||
// an array of WarpMmaPV::FragmentA (which is what mma_pv expects).
|
||||
static_assert(decltype(frag_p)::kElements == kIterationsPV * Mma_tile_o::MMAS_M * WarpMmaPV::FragmentA::kElements);
|
||||
static_assert(decltype(frag_p)::kElements == kIterationsPV * Mma_tile_o::MMAS_M * WarpMmaPV::FragmentA::kElements, "");
|
||||
const auto frag_p_reshaped = reinterpret_cast<const cutlass::Array<Element, WarpMmaPV::FragmentA::kElements> (&)[kIterationsPV]>(frag_p);
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < kIterationsPV; ++ki ) {
|
||||
|
|
@ -589,7 +589,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
|
|||
}
|
||||
|
||||
softmax.reduce_max_after_sync_(p_max_o, rows);
|
||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||
static_assert(Mma_tile_o::MMAS_M == 1, "");
|
||||
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
|
||||
p_max_o[jj][0] *= params.scale_bmm1;
|
||||
}
|
||||
|
|
@ -601,7 +601,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
|
|||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||
static_assert(Mma_tile_o::MMAS_M == 1, "");
|
||||
float p_sum_o[kOutputRowsPerThread][Mma_tile_o::MMAS_M];
|
||||
softmax.reduce_sum_after_sync_(p_sum_o, rows);
|
||||
if (!Is_first) {
|
||||
|
|
@ -625,7 +625,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
|
|||
|
||||
// Load from shared memory.
|
||||
using ArrayTypeO = cutlass::Array<ElementAccum, OutputTileThreadMap::kElementsPerAccess>;
|
||||
static_assert(OutputTileThreadMap::kElementsPerAccess * kOutputRowsPerThread == Smem_O::kIterationsStore * Smem_O::OutputFragment::kElements);
|
||||
static_assert(OutputTileThreadMap::kElementsPerAccess * kOutputRowsPerThread == Smem_O::kIterationsStore * Smem_O::OutputFragment::kElements, "");
|
||||
cutlass::multiplies<ArrayTypeO> multiply_fragments;
|
||||
if (!Is_first) {
|
||||
auto out_reshaped = reinterpret_cast<ArrayTypeO (&)[kOutputRowsPerThread]>(out);
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ void run_fmha_loop_(Launch_params<FMHA_fprop_params> &launch_params,
|
|||
});
|
||||
}
|
||||
|
||||
void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params,
|
||||
TORCH_API void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params,
|
||||
const bool configure) {
|
||||
BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
|
||||
using elem_type = std::conditional<IsBf16Const, cutlass::bfloat16_t, cutlass::half_t>::type;
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ struct Gmem_tile_mma_s : public Base {
|
|||
// Store to global memory.
|
||||
template<typename Mask, typename Fragment>
|
||||
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
|
||||
static_assert(Fragment::kStorageElements == 4);
|
||||
static_assert(Fragment::kStorageElements == 4, "");
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; mi++ ) {
|
||||
#pragma unroll
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ struct FMHA_kernel_traits {
|
|||
#endif
|
||||
using ElementAccum = float;
|
||||
|
||||
static_assert(WARPS_M == 1);
|
||||
static_assert(WARPS_M == 1, "");
|
||||
using ThreadblockShapeQK = cutlass::gemm::GemmShape<STEP, S, D>;
|
||||
using WarpCountQK = cutlass::gemm::GemmShape<WARPS_M, WARPS_N, 1>;
|
||||
using WarpShapeQK = cutlass::gemm::GemmShape<
|
||||
|
|
@ -144,7 +144,7 @@ struct FMHA_kernel_traits {
|
|||
static constexpr size_t BYTES_PER_SMEM_Q = ThreadblockShapeQK::kM * ThreadblockShapeQK::kK * sizeof(Element);
|
||||
static constexpr size_t BYTES_PER_SMEM_K = ThreadblockShapeQK::kN * ThreadblockShapeQK::kK * sizeof(Element);
|
||||
static constexpr size_t BYTES_PER_SMEM_V = ThreadblockShapePV::kN * ThreadblockShapePV::kK * sizeof(Element);
|
||||
static_assert(BYTES_PER_SMEM_K == BYTES_PER_SMEM_V);
|
||||
static_assert(BYTES_PER_SMEM_K == BYTES_PER_SMEM_V, "");
|
||||
static constexpr size_t BYTES_PER_SMEM_QK = BYTES_PER_SMEM_Q + BYTES_PER_SMEM_K;
|
||||
// The extra amount of shared memory needed to load V.
|
||||
static constexpr size_t BYTES_PER_SMEM_V_EXTRA = SHARE_SMEM_FOR_K_AND_V ? 0u : BYTES_PER_SMEM_V;
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ struct Smem_tile_reduce {
|
|||
|
||||
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
|
||||
static constexpr int COLS = WARPS_N;
|
||||
static_assert(COLS == 4 || COLS == 8);
|
||||
static_assert(COLS == 4 || COLS == 8, "");
|
||||
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
|
||||
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
|
||||
static constexpr int ELTS_PER_TILE = ROWS * COLS;
|
||||
|
|
@ -263,7 +263,7 @@ struct Softmax_base {
|
|||
};
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
static_assert(MMAS_N % 2 == 0);
|
||||
static_assert(MMAS_N % 2 == 0, "");
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ni += 2 ) {
|
||||
uint4 random_uint4 = ph0();
|
||||
|
|
@ -319,7 +319,7 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
|
|||
static constexpr int MMAS_N = Base::MMAS_N;
|
||||
|
||||
using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
|
||||
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
|
||||
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N, "");
|
||||
// Ctor.
|
||||
template<typename Params>
|
||||
inline __device__ Softmax(const Params &params, void *smem, int tidx)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ template<int kRows, int kRowsPerMma, int kWarpCountM>
|
|||
struct Smem_tile_softmax_lse {
|
||||
|
||||
static constexpr int kMmaM = (kRows / kWarpCountM) / kRowsPerMma;
|
||||
static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows);
|
||||
static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows, "");
|
||||
// static_assert(kWarpCountM == 1);
|
||||
// Otherwise we might need to check warp_idx / kWarpCountM == 0 instead of just warp_idx == 0
|
||||
|
||||
|
|
|
|||
|
|
@ -332,7 +332,7 @@ __device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
|||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4, "");
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
|
|
|
|||
|
|
@ -107,13 +107,14 @@ Tensor transformer_encoder_layer_forward(
|
|||
if (norm_first) {
|
||||
x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor);
|
||||
}
|
||||
#if USE_FLASH_ATTENTION
|
||||
|
||||
#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION
|
||||
if (x.is_nested() && x.is_cuda() && x.dtype() == at::kHalf && !mask.has_value() &&
|
||||
(embed_dim / num_heads == 16 ||
|
||||
embed_dim / num_heads == 32 ||
|
||||
embed_dim / num_heads == 64 ||
|
||||
embed_dim / num_heads == 128)) {
|
||||
TORCH_WARN_ONCE("USING FLASH ATTENTION WITH NT");
|
||||
TORCH_WARN_ONCE("transformer_encoder_layer_forward is using flash attention.");
|
||||
x = at::linear(x, qkv_weight, qkv_bias);
|
||||
x = x.view({x.size(0), -1, 3, num_heads, embed_dim / num_heads});
|
||||
x = flash_attention_helper(x, x, x, 0.0, false);
|
||||
|
|
@ -135,7 +136,7 @@ Tensor transformer_encoder_layer_forward(
|
|||
false /* need_weights */,
|
||||
true /* average_attn_weights */,
|
||||
mask_type));
|
||||
#if USE_FLASH_ATTENTION
|
||||
#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION
|
||||
}
|
||||
#endif
|
||||
add_in_place(x, src, use_nested_tensor);
|
||||
|
|
|
|||
|
|
@ -1461,6 +1461,7 @@ aten_cuda_cu_source_list = [
|
|||
"aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp",
|
||||
"aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.cpp",
|
||||
"aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp",
|
||||
"aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp",
|
||||
]
|
||||
|
||||
# Files using thrust::sort_by_key need to be linked last
|
||||
|
|
|
|||
Loading…
Reference in a new issue