Get flash_attn to compile for CUDA 11.6 linux nightly build (#84941)

This PR only attempts to get this code to compile for all archs so that we can dispatch to it in https://github.com/pytorch/pytorch/pull/84653
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84941
Approved by: https://github.com/drisspg, https://github.com/malfet
cpuhrsch 2022-09-26 20:49:19 +00:00 committed by PyTorch MergeBot
parent 15435325eb
commit 6a04df3ac8
14 changed files with 36 additions and 28 deletions

View file

@@ -725,7 +725,10 @@ set(BUILD_ONEDNN_GRAPH OFF)
include(cmake/Dependencies.cmake)
# Moved this cmake set option down here because CMAKE_CUDA_COMPILER_VERSION is not available until now
option(USE_FLASH_ATTENTION "Whether to build the flash_attention kernel for scaled dot product attention" OFF)
cmake_dependent_option(
USE_FLASH_ATTENTION
"Whether to build the flash_attention kernel for scaled dot product attention" ON
"USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
if(USE_FLASH_ATTENTION)
ADD_DEFINITIONS(-DUSE_FLASH_ATTENTION)
ENDIF()
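For illustration only (this snippet is not part of the diff): the ADD_DEFINITIONS(-DUSE_FLASH_ATTENTION) line is what lets C++ translation units compile the kernel path in or out, along the lines of the sketch below. The names here are made up for the example.

#include <stdexcept>

// Hypothetical consumer of the USE_FLASH_ATTENTION compile definition.
#ifdef USE_FLASH_ATTENTION
inline void run_flash_attention_path() {
  // real kernel launch would go here
}
#else
inline void run_flash_attention_path() {
  // built without the kernel: keep the symbol, fail loudly if reached
  throw std::runtime_error("this build does not include flash attention");
}
#endif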

View file

@@ -63,7 +63,7 @@ struct FMHAEpilogue {
Element, ElementC, /*ElementsPerAccess=*/4, ThreadblockShape, typename WarpMma::Shape,
typename WarpMma::Policy::Operator::Shape, typename OutputTileThreadMap::CompactedThreadMap>;
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
static_assert(WarpTileIterator::kIterations == kIterationsStore);
static_assert(WarpTileIterator::kIterations == kIterationsStore, "");
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
using OutputFragment = typename SharedLoadIterator::Fragment;
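Most of the changes in this PR are the same mechanical fix seen here: giving static_assert an explicit (empty) message. The single-argument form is a C++17 addition, so the two-argument form is presumably what keeps the CUDA 11.6 toolchains used by the nightly builds happy. A minimal illustration (not part of the diff):

// Compiles under C++11 and later:
static_assert(sizeof(int) == 4, "");
// Requires C++17 or later; older nvcc/host-compiler dialects reject it:
// static_assert(sizeof(int) == 4);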

View file

@@ -151,4 +151,4 @@ struct Launch_params{
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
TORCH_API void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
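TORCH_API is PyTorch's symbol-visibility macro; adding it here exports run_fmha_fprop from the library so other components (such as the dispatch added in #84653) can call it. A rough sketch of how this kind of macro is commonly defined (illustrative only, not the actual TORCH_API definition, which also distinguishes export from import per library):

// Hypothetical visibility macro in the spirit of TORCH_API.
#if defined(_WIN32)
#define MY_API __declspec(dllexport)
#else
#define MY_API __attribute__((visibility("default")))
#endif

MY_API void some_exported_function();  // symbol stays visible across shared-library boundaries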

View file

@@ -26,6 +26,7 @@
*
******************************************************************************/
#ifdef USE_FLASH_ATTENTION
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/NativeFunctions.h>
@@ -241,3 +242,4 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
return result;
}
} // namespace fmha
#endif

View file

@@ -6,6 +6,7 @@
namespace fmha {
TORCH_API
std::vector<at::Tensor>
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
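The comments document the packed variable-length layout: all sequences in the batch are concatenated along the first dimension, so total_q is the sum of the per-sequence lengths and offsets are recovered from a cumulative-length table. A small illustration of that bookkeeping (hypothetical helper, not part of the API):

#include <cstdint>
#include <numeric>
#include <vector>

// Build the cumulative sequence-length table used to index a packed
// [total_q, num_heads, head_size] tensor; cu.back() == total_q == sum of s_i.
std::vector<int32_t> build_cu_seqlens(const std::vector<int32_t>& seqlens) {
  std::vector<int32_t> cu(seqlens.size() + 1, 0);
  std::partial_sum(seqlens.begin(), seqlens.end(), cu.begin() + 1);
  return cu;
}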

View file

@@ -102,7 +102,7 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
// If V is stored in shared memory, we can't load K using the same shared memory.
static_assert(Kernel_traits::V_IN_REGS);
static_assert(Kernel_traits::V_IN_REGS, "");
static constexpr size_t SMEM_OFFSET_O = Kernel_traits::BYTES_PER_SMEM_Q;
static constexpr size_t SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + sizeof(typename Smem_O::SharedStorage);
@@ -161,7 +161,7 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS;
static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V);
static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V, "");
static constexpr size_t SMEM_OFFSET_V = Kernel_traits::BYTES_PER_SMEM_Q + (SHARE_SMEM_FOR_K_AND_V ? 0 : Kernel_traits::BYTES_PER_SMEM_K);
static constexpr size_t SMEM_OFFSET_O = SMEM_OFFSET_V + Kernel_traits::BYTES_PER_SMEM_V;
@@ -298,7 +298,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
// Wind gmem tiles to the correct position.
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0, "");
const int begin_og = begin;
begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;
const int steps_og = steps;
@@ -428,7 +428,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
const int warp_idx = threadIdx.x / 32;
iter_V.add_tile_offset({kIterationsPV * warp_idx, 0});
typename WarpIteratorV::Fragment frag_v[kIterationsPV];
static_assert(WarpIteratorV::Fragment::kStorageElements == 4 * Mma_tile_o::MMAS_N || WarpIteratorV::Fragment::kStorageElements == 2 * Mma_tile_o::MMAS_N );
static_assert(WarpIteratorV::Fragment::kStorageElements == 4 * Mma_tile_o::MMAS_N || WarpIteratorV::Fragment::kStorageElements == 2 * Mma_tile_o::MMAS_N, "");
#pragma unroll
for( int ki = 0; ki < kIterationsPV; ++ki ) {
iter_V.load(frag_v[ki]);
@@ -463,8 +463,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
gemm_q_k(mma_qk, acc_p);
typename Smem_O::OutputFragment out[Smem_O::kIterationsStore];
static_assert(GmemIteratorOAccum::kIterations == Smem_O::kIterationsStore);
static_assert(GmemIteratorO::kIterations == Smem_O::kIterationsStore);
static_assert(GmemIteratorOAccum::kIterations == Smem_O::kIterationsStore, "");
static_assert(GmemIteratorO::kIterations == Smem_O::kIterationsStore, "");
if (!Is_first) {
#pragma unroll
for (int iter = 0; iter < GmemIteratorOAccum::kIterations; ++iter) {
@@ -536,8 +536,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
}
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M, "");
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N, "");
softmax.pack_noconvert(acc_p);
cutlass::NumericArrayConverter<Element, ElementAccum, decltype(acc_p)::kElements, cutlass::FloatRoundStyle::round_to_nearest> convert_p;
auto frag_p = convert_p(acc_p);
@@ -558,13 +558,13 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Declare the accumulators for the 2nd gemm.
WarpMmaPV mma_pv;
typename WarpMmaPV::FragmentC acc_o;
static_assert(WarpMmaPV::FragmentC::kElements == Mma_tile_o::MMAS_M * Mma_tile_o::MMAS_N * 8);
static_assert(WarpMmaPV::FragmentC::kElements == Mma_tile_o::MMAS_M * Mma_tile_o::MMAS_N * 8, "");
acc_o.clear();
// For some reason, WarpMmaPV::FragmentA has length K * N * (8|4) instead of just N * (8|4).
// We have to first cast frag_p to be array of k x (N * (8|4)), then cast each row to be
// an array of WarpMmaPV::FragmentA (which is what mma_pv expects).
static_assert(decltype(frag_p)::kElements == kIterationsPV * Mma_tile_o::MMAS_M * WarpMmaPV::FragmentA::kElements);
static_assert(decltype(frag_p)::kElements == kIterationsPV * Mma_tile_o::MMAS_M * WarpMmaPV::FragmentA::kElements, "");
const auto frag_p_reshaped = reinterpret_cast<const cutlass::Array<Element, WarpMmaPV::FragmentA::kElements> (&)[kIterationsPV]>(frag_p);
#pragma unroll
for( int ki = 0; ki < kIterationsPV; ++ki ) {
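The frag_p_reshaped cast in the hunk above views one flat fragment as kIterationsPV rows of smaller fragments without copying. The same reinterpret_cast-to-array-reference trick in isolation, with plain arrays instead of the CUTLASS types (illustration only, not part of the diff):

#include <cstdio>

int main() {
  float flat[12] = {};                                    // think: K * (N * elems) fragment storage
  auto& rows = reinterpret_cast<float (&)[3][4]>(flat);   // view as 3 rows of 4 elements, no copy
  rows[1][2] = 7.0f;
  std::printf("%f\n", flat[1 * 4 + 2]);                   // prints 7.000000: same storage
}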
@@ -589,7 +589,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
}
softmax.reduce_max_after_sync_(p_max_o, rows);
static_assert(Mma_tile_o::MMAS_M == 1);
static_assert(Mma_tile_o::MMAS_M == 1, "");
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
p_max_o[jj][0] *= params.scale_bmm1;
}
@@ -601,7 +601,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Make sure the data is in shared memory.
__syncthreads();
static_assert(Mma_tile_o::MMAS_M == 1);
static_assert(Mma_tile_o::MMAS_M == 1, "");
float p_sum_o[kOutputRowsPerThread][Mma_tile_o::MMAS_M];
softmax.reduce_sum_after_sync_(p_sum_o, rows);
if (!Is_first) {
@@ -625,7 +625,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Load from shared memory.
using ArrayTypeO = cutlass::Array<ElementAccum, OutputTileThreadMap::kElementsPerAccess>;
static_assert(OutputTileThreadMap::kElementsPerAccess * kOutputRowsPerThread == Smem_O::kIterationsStore * Smem_O::OutputFragment::kElements);
static_assert(OutputTileThreadMap::kElementsPerAccess * kOutputRowsPerThread == Smem_O::kIterationsStore * Smem_O::OutputFragment::kElements, "");
cutlass::multiplies<ArrayTypeO> multiply_fragments;
if (!Is_first) {
auto out_reshaped = reinterpret_cast<ArrayTypeO (&)[kOutputRowsPerThread]>(out);

View file

@@ -88,7 +88,7 @@ void run_fmha_loop_(Launch_params<FMHA_fprop_params> &launch_params,
});
}
void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params,
TORCH_API void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params,
const bool configure) {
BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
using elem_type = std::conditional<IsBf16Const, cutlass::bfloat16_t, cutlass::half_t>::type;
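BOOL_SWITCH turns the runtime flag params.is_bf16 into a compile-time constant so that std::conditional can select the CUTLASS element type. A common shape for that kind of dispatch, sketched as a plain function template rather than PyTorch's actual macro (an assumption about the pattern, not the macro's real expansion):

#include <type_traits>
#include <utility>

// Run f with the flag baked in as std::true_type / std::false_type.
template <typename F>
void bool_switch(bool flag, F&& f) {
  if (flag) {
    std::forward<F>(f)(std::true_type{});
  } else {
    std::forward<F>(f)(std::false_type{});
  }
}
// Usage (names illustrative):
//   bool_switch(is_bf16, [&](auto is_bf16_c) {
//     using elem_type = std::conditional_t<decltype(is_bf16_c)::value, Bf16T, HalfT>;
//   });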

View file

@@ -118,7 +118,7 @@ struct Gmem_tile_mma_s : public Base {
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
static_assert(Fragment::kStorageElements == 4);
static_assert(Fragment::kStorageElements == 4, "");
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll

View file

@@ -85,7 +85,7 @@ struct FMHA_kernel_traits {
#endif
using ElementAccum = float;
static_assert(WARPS_M == 1);
static_assert(WARPS_M == 1, "");
using ThreadblockShapeQK = cutlass::gemm::GemmShape<STEP, S, D>;
using WarpCountQK = cutlass::gemm::GemmShape<WARPS_M, WARPS_N, 1>;
using WarpShapeQK = cutlass::gemm::GemmShape<
@@ -144,7 +144,7 @@ struct FMHA_kernel_traits {
static constexpr size_t BYTES_PER_SMEM_Q = ThreadblockShapeQK::kM * ThreadblockShapeQK::kK * sizeof(Element);
static constexpr size_t BYTES_PER_SMEM_K = ThreadblockShapeQK::kN * ThreadblockShapeQK::kK * sizeof(Element);
static constexpr size_t BYTES_PER_SMEM_V = ThreadblockShapePV::kN * ThreadblockShapePV::kK * sizeof(Element);
static_assert(BYTES_PER_SMEM_K == BYTES_PER_SMEM_V);
static_assert(BYTES_PER_SMEM_K == BYTES_PER_SMEM_V, "");
static constexpr size_t BYTES_PER_SMEM_QK = BYTES_PER_SMEM_Q + BYTES_PER_SMEM_K;
// The extra amount of shared memory needed to load V.
static constexpr size_t BYTES_PER_SMEM_V_EXTRA = SHARE_SMEM_FOR_K_AND_V ? 0u : BYTES_PER_SMEM_V;
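These constants size the shared-memory staging buffers from the GEMM tile shapes; the BYTES_PER_SMEM_K == BYTES_PER_SMEM_V assertion just checks that the K tile of the first GEMM and the V tile of the second occupy the same storage. Worked through with made-up tile sizes (the real values come from the traits above):

#include <cstddef>

// Illustrative shapes only: a 16x128x64 Q*K^T tile and a 16x64x128 P*V tile of 2-byte elements.
constexpr std::size_t kM = 16, kN_qk = 128, kK_qk = 64, kN_pv = 64, kK_pv = 128, kElem = 2;
constexpr std::size_t bytes_q = kM * kK_qk * kElem;      // 16 * 64 * 2  = 2048
constexpr std::size_t bytes_k = kN_qk * kK_qk * kElem;   // 128 * 64 * 2 = 16384
constexpr std::size_t bytes_v = kN_pv * kK_pv * kElem;   // 64 * 128 * 2 = 16384
static_assert(bytes_q == 2048, "");
static_assert(bytes_k == bytes_v, "");                   // mirrors the assert above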

View file

@@ -78,7 +78,7 @@ struct Smem_tile_reduce {
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
static constexpr int COLS = WARPS_N;
static_assert(COLS == 4 || COLS == 8);
static_assert(COLS == 4 || COLS == 8, "");
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
static constexpr int ELTS_PER_TILE = ROWS * COLS;
@@ -263,7 +263,7 @@ struct Softmax_base {
};
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
static_assert(MMAS_N % 2 == 0);
static_assert(MMAS_N % 2 == 0, "");
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni += 2 ) {
uint4 random_uint4 = ph0();
@@ -319,7 +319,7 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
static constexpr int MMAS_N = Base::MMAS_N;
using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N, "");
// Ctor.
template<typename Params>
inline __device__ Softmax(const Params &params, void *smem, int tidx)

View file

@@ -12,7 +12,7 @@ template<int kRows, int kRowsPerMma, int kWarpCountM>
struct Smem_tile_softmax_lse {
static constexpr int kMmaM = (kRows / kWarpCountM) / kRowsPerMma;
static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows);
static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows, "");
// static_assert(kWarpCountM == 1);
// Otherwise we might need to check warp_idx / kWarpCountM == 0 instead of just warp_idx == 0
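To make the factorization concrete with made-up values: kRows = 16, kWarpCountM = 1, kRowsPerMma = 8 gives kMmaM = (16 / 1) / 8 = 2, and 2 * 8 * 1 == 16, so the assert holds; it fires whenever the rows do not divide evenly across warps and MMA rows. Not part of the diff:

constexpr int kRows = 16, kWarpCountM = 1, kRowsPerMma = 8;    // illustrative values
constexpr int kMmaM = (kRows / kWarpCountM) / kRowsPerMma;     // 2
static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows, "");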

View file

@@ -332,7 +332,7 @@ __device__ inline T operator()(T const & x, T const & y) { return x + y; }
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4, "");
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
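Allreduce halves THREADS at each step, the usual butterfly reduction over warp shuffles, with the recursion presumably terminated by a small-THREADS specialization. A self-contained sketch of that pattern (an illustration of the technique, not this file's exact implementation; assumes T is a shuffle-capable scalar type):

// Butterfly all-reduce across the low THREADS lanes of a warp.
template <int THREADS>
struct AllreduceSketch {
  static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4, "");
  template <typename T, typename Operator>
  static __device__ inline T run(T x, Operator& op) {
    constexpr int OFFSET = THREADS / 2;
    x = op(x, __shfl_xor_sync(0xffffffffu, x, OFFSET));   // combine with lane ^ OFFSET
    return AllreduceSketch<OFFSET>::run(x, op);
  }
};
// Final step: combine with the neighbouring lane.
template <>
struct AllreduceSketch<2> {
  template <typename T, typename Operator>
  static __device__ inline T run(T x, Operator& op) {
    return op(x, __shfl_xor_sync(0xffffffffu, x, 1));
  }
};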

View file

@@ -107,13 +107,14 @@ Tensor transformer_encoder_layer_forward(
if (norm_first) {
x = norm(x, embed_dim, layer_norm_eps, layer_norm_weight_1, layer_norm_bias_1, use_nested_tensor);
}
#if USE_FLASH_ATTENTION
#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION
if (x.is_nested() && x.is_cuda() && x.dtype() == at::kHalf && !mask.has_value() &&
(embed_dim / num_heads == 16 ||
embed_dim / num_heads == 32 ||
embed_dim / num_heads == 64 ||
embed_dim / num_heads == 128)) {
TORCH_WARN_ONCE("USING FLASH ATTENTION WITH NT");
TORCH_WARN_ONCE("transformer_encoder_layer_forward is using flash attention.");
x = at::linear(x, qkv_weight, qkv_bias);
x = x.view({x.size(0), -1, 3, num_heads, embed_dim / num_heads});
x = flash_attention_helper(x, x, x, 0.0, false);
@@ -135,7 +136,7 @@ Tensor transformer_encoder_layer_forward(
false /* need_weights */,
true /* average_attn_weights */,
mask_type));
#if USE_FLASH_ATTENTION
#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION
}
#endif
add_in_place(x, src, use_nested_tensor);
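The guard above only takes the flash-attention path for nested, half-precision CUDA inputs with no mask and a head dimension of 16, 32, 64, or 128; anything else falls through to the existing attention path. The condition restated as a standalone predicate (a sketch with a hypothetical name, not a function in the codebase):

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

bool can_use_flash_attention(const at::Tensor& x,
                             const c10::optional<at::Tensor>& mask,
                             int64_t embed_dim,
                             int64_t num_heads) {
  const int64_t head_dim = embed_dim / num_heads;
  return x.is_nested() && x.is_cuda() && x.dtype() == at::kHalf &&
         !mask.has_value() &&
         (head_dim == 16 || head_dim == 32 || head_dim == 64 || head_dim == 128);
}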

View file

@@ -1461,6 +1461,7 @@ aten_cuda_cu_source_list = [
"aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp",
"aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.cpp",
"aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp",
"aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp",
]
# Files using thrust::sort_by_key need to be linked last