Update sdp dispatch logic to enable fused backward (#89154)

# Summary
Reorganizes how the sdp dispatch logic is done in order to enable backward for the fused kernels.
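
As a minimal sketch of what this enables (assuming a CUDA build with the fused kernels compiled in; the shapes and dtype below are illustrative only), gradients can now flow through a fused kernel chosen by the dispatch logic instead of always falling back to the math path when an input requires grad:

import torch

# Query/Key/Value in (batch, num_heads, seq_len, head_dim) layout, as expected by
# torch.nn.functional._scaled_dot_product_attention.
q = torch.randn(4, 2, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(4, 2, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(4, 2, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# Restrict dispatch to the memory-efficient kernel; before this change the fused
# kernels were only reachable when no input required grad.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out, _ = torch.nn.functional._scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)

# Backward now routes through _scaled_dot_product_efficient_attention_backward.
out.sum().backward()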

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89154
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Driss Guessous 2022-11-21 20:02:09 +00:00 committed by PyTorch MergeBot
parent cf9476554f
commit 1d9e1fca97
15 changed files with 500 additions and 138 deletions

View file

@ -13252,19 +13252,40 @@
CPU, NestedTensorCPU, Meta: _fused_sdp_choice_cpp
CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
# Register the math kernel for cpu
- func: _scaled_dot_product_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
variants: function
dispatch:
CUDA: _scaled_dot_product_attention_forward_cuda
CPU: _scaled_dot_product_attention_forward_math
NestedTensorCUDA: _scaled_dot_product_attention_forward_nested
NestedTensorCPU: _scaled_dot_product_attention_forward_math
Meta: _scaled_dot_product_attention_forward_math
- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
variants: function
- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool return_softmax=False, bool is_causal=False) -> (Tensor, Tensor, Tensor)
dispatch:
CUDA: _scaled_dot_product_flash_attention_cuda
NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor)
dispatch:
CUDA: _scaled_dot_product_efficient_attention_cuda
NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda
- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor)
dispatch:
CUDA: _scaled_dot_product_efficient_attention_backward_cuda
# Returns output, softmax_logsumexp, softmax
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, bool return_softmax, float dropout_p, bool is_causal) -> (Tensor, Tensor, Tensor)
variants: function
dispatch:
CUDA: _flash_attention_forward
# Returns output, logsumexp if compute_logsumexp
- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor)
variants: function
dispatch:
CUDA: _efficient_attention_forward
- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor)
variants: function
dispatch:
CUDA: _efficient_attention_backward
- func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
variants: function
dispatch:
@ -13290,21 +13311,6 @@
structured: True
variants: function
- func: _flash_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal) -> Tensor
variants: function
dispatch:
CUDA: flash_scaled_dot_product_attention
- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor)
variants: function
dispatch:
CUDA: _efficient_attention_forward
- func: _efficient_attention_backward(Tensor grad, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor out, bool is_causal=False) -> (Tensor, Tensor, Tensor)
variants: function
dispatch:
CUDA: _efficient_attention_backward
- func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? incr_value=None) -> (Tensor, Tensor, Tensor)
variants: function
dispatch:

View file

@ -214,26 +214,6 @@ Tensor NestedTensor_to_padded_tensor_cuda(
return NestedTensor_to_padded_tensor_generic(t, padding, output_size);
}
std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_nested(
const Tensor& query_, const Tensor& key, const Tensor& value,
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) {
// Determine which efficient kernel to use
sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal};
auto backend = select_sdp_backend(kernel_params);
switch(backend){
case sdp::SDPBackend::flash_attention:
// TODO: enable flash attention kernel
return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal);
case sdp::SDPBackend::efficient_attention:
return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal);
case sdp::SDPBackend::math:
return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
default:
TORCH_CHECK(false, "Unsupported backend for scaled_dot_product_attention");
return std::make_tuple(Tensor(), Tensor());
}
}
namespace{
/**
@ -340,19 +320,80 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) {
}
} // namespace
std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_nestedtensor_cuda(
const Tensor& query,
const Tensor& key,
const Tensor& value,
double dropout_p,
bool need_atten_weights,
bool return_softmax,
bool is_causal) {
TORCH_CHECK(false, "There are currently cuda memory errors being returned from this path.")
// Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
// Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
// Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
const int64_t num_heads = query.size(1);
const int64_t head_dim = query.size(3);
// Query -> Query (Batch x {Q_seq_len} x Num_heads x Dim_per_head)
// Key -> Key (Batch x {KV_seq_len} x Num_heads x Dim_per_head)
// Value -> Value (Batch x {KV_seq_len} x Num_heads x Dim_per_head)
Tensor q_t = query.transpose(1, 2).contiguous();
Tensor k_t = key.transpose(1, 2).contiguous();
Tensor v_t = value.transpose(1, 2).contiguous();
// K and V have to have the same Nnz, should probably torch_check
// assume in order to not iterate over v
auto cumulative_and_max_q = cumulative_and_max_seq_len(q_t);
auto cumulative_and_max_k = cumulative_and_max_seq_len(k_t);
Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q);
Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k);
const int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q);
const int64_t max_seqlen_batch_k = std::get<1>(cumulative_and_max_k);
const int64_t Nnz_q = cumulative_sequence_length_q[-1].item<int64_t>();
const int64_t Nnz_kv = cumulative_sequence_length_k[-1].item<int64_t>();
auto query_buffer_reshaped =
get_buffer(q_t).view({Nnz_q, num_heads, head_dim});
auto key_buffer_reshaped =
get_buffer(k_t).view({Nnz_kv, num_heads, head_dim});
auto value_buffer_reshaped =
get_buffer(v_t).view({Nnz_kv, num_heads, head_dim});
auto attention_and_lse_and_softmax =
at::_flash_attention_forward(
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_k,
max_seqlen_batch_q,
max_seqlen_batch_k,
return_softmax,
dropout_p,
is_causal);
// Reshape output to convert nnz to batch_size and seq_len
Tensor attention = std::get<0>(attention_and_lse_and_softmax);
attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()).transpose(1,2);
return std::tie(attention, std::get<1>(attention_and_lse_and_softmax), std::get<2>(attention_and_lse_and_softmax));
}
std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_nestedtensor_cuda(
const Tensor& query,
const Tensor& key,
const Tensor& value,
bool compute_log_sumexp,
bool is_causal) {
// Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
// Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
// Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
const int64_t num_heads = query.size(1);
const int64_t head_dim = query.size(3);
Tensor q_t = query.transpose(1, 2);
Tensor k_t = key.transpose(1, 2);
Tensor v_t = value.transpose(1, 2);
@ -432,7 +473,7 @@ std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
{Nnz_kv, num_heads, head_dim},
{nnz_v_stride, head_v_stride, head_dim_stride},
value_impl->get_storage_offsets()[0]);
std::tuple<Tensor, Tensor> attention_and_weights =
std::tuple<Tensor, Tensor> attention_and_logsumexp=
at::_efficient_attention_forward(
query_buffer_reshaped.unsqueeze(0),
key_buffer_reshaped.unsqueeze(0),
@ -440,14 +481,14 @@ std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
cumulative_sequence_length_q,
cumulative_sequence_length_k,
max_seqlen_batch_q,
false,
false);
compute_log_sumexp,
is_causal);
// Reshape output to convert nnz to batch_size and seq_len
Tensor attention = std::get<0>(attention_and_weights);
Tensor attention = std::get<0>(attention_and_logsumexp);
attention =
wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone())
.transpose(1, 2);
return std::tie(attention, std::get<1>(attention_and_weights));
return std::tie(attention, std::get<1>(attention_and_logsumexp));
}
Tensor flash_attention_helper(
@ -492,7 +533,7 @@ Tensor flash_attention_helper(
// If we are passing in query, key, value all the same tensors then we have
// packed them into one tensor and need to slice for flash attention
Tensor attention =
at::_flash_scaled_dot_product_attention(
std::get<0>(at::_flash_attention_forward(
q,
k,
v,
@ -500,8 +541,9 @@ Tensor flash_attention_helper(
cumulative_sequence_length_q,
max_seqlen_batch_q,
max_seqlen_batch_q,
false /*return_softmax*/,
dropout_p,
is_causal);
is_causal));
// Output of flash_attention is a regular tensor lets wrap it back up to
// form a nested tensor

View file

@ -678,20 +678,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> native_decoder_only_multi_head_attent
// L: Target sequence length
// E: Embedding dimension
std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
const Tensor& query_, const Tensor& key, const Tensor& value,
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) {
if (query_.requires_grad() || key.requires_grad() || value.requires_grad()){
return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
}
return at::_scaled_dot_product_attention_forward(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
}
int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){
return static_cast<int64_t>(sdp::SDPBackend::math);
}
std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_math(
const Tensor& query_,
const Tensor& key,
const Tensor& value,
@ -699,14 +685,49 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_math(
double dropout_p,
bool need_attn_weights,
bool is_causal) {
return at::_scaled_dot_product_attention_math(
query_,
key,
value,
attn_mask_,
dropout_p,
need_attn_weights,
is_causal);
// TODO: The second return is the attention weights if the math kernel is
// used. The fused kernels do not return this Tensor so for the fused kernels
// The second return SHOULD always be an empty Tensor, unless need_attn_weights
// is true (in which case the fused kernels would not be called). This blows up
// op_info tests.
int64_t choice_int = at::_fused_sdp_choice(
query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
switch (backend) {
case sdp::SDPBackend::flash_attention: {
auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
query_, key, value, dropout_p, need_attn_weights, is_causal);
return std::make_tuple(
std::move(std::get<0>(out_lse_softmax)),
std::move(std::get<2>(out_lse_softmax)));
}
case sdp::SDPBackend::efficient_attention: {
bool compute_logsumexp =
(query_.requires_grad() || key.requires_grad() ||
value.requires_grad());
return at::_scaled_dot_product_efficient_attention(
query_, key, value, compute_logsumexp, is_causal);
}
case sdp::SDPBackend::math:
return at::_scaled_dot_product_attention_math(
query_,
key,
value,
attn_mask_,
dropout_p,
need_attn_weights,
is_causal);
default:
TORCH_CHECK(
false,
"No viable backend for scaled_dot_product_attention was found.");
return std::make_tuple(Tensor(), Tensor());
}
}
int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){
return static_cast<int64_t>(sdp::SDPBackend::math);
}
std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(

View file

@ -678,12 +678,12 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
return std::make_tuple(std::move(proj), std::move(qkt));
}
std::tuple<Tensor, Tensor> flash_attention_helper_dense_unpacked(
std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
const Tensor& query,
const Tensor& key,
const Tensor& value,
double dropout_p,
bool need_atten_weights,
bool return_softmax,
bool is_causal) {
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
@ -726,8 +726,9 @@ std::tuple<Tensor, Tensor> flash_attention_helper_dense_unpacked(
Tensor key_reshaped = k_t.reshape({Nnz_kv, num_heads, head_dim});
Tensor value_reshaped = v_t.reshape({Nnz_kv, num_heads, head_dim});
Tensor attention =
at::_flash_scaled_dot_product_attention(
Tensor attention, log_sumexp, softmax;
std::tie(attention, log_sumexp, softmax) =
at::_flash_attention_forward(
query_reshaped,
key_reshaped,
value_reshaped,
@ -735,15 +736,17 @@ std::tuple<Tensor, Tensor> flash_attention_helper_dense_unpacked(
cumulative_sequence_length_k,
max_seqlen_batch_q,
max_seqlen_batch_k,
return_softmax,
dropout_p,
is_causal);
// Reshape output to convert nnz to batch_size and seq_len
attention =
attention.view({batch_size, max_seqlen_batch_q, num_heads, head_dim}).transpose(1,2);
return std::tuple<Tensor, Tensor>(attention, Tensor());
return std::make_tuple(attention, log_sumexp, softmax);
}
std::tuple<Tensor, Tensor> mem_eff_helper(
std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
const Tensor& query,
const Tensor& key,
const Tensor& value,
@ -767,26 +770,7 @@ std::tuple<Tensor, Tensor> mem_eff_helper(
compute_log_sumexp,
is_causal);
attention = attention.transpose(1,2);
return std::make_tuple(std::move(attention), Tensor());
}
std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_cuda(
const Tensor& query_, const Tensor& key, const Tensor& value,
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) {
// Determine which efficient kernel to use
sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal};
auto backend = select_sdp_backend(kernel_params);
switch(backend){
case sdp::SDPBackend::flash_attention:
return flash_attention_helper_dense_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal);
case sdp::SDPBackend::efficient_attention:
return mem_eff_helper(query_, key , value, need_attn_weights, is_causal);
case sdp::SDPBackend::math:
return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
default:
TORCH_CHECK(false, "No viable backend for scaled_dot_product_attention was found.");
return std::make_tuple(Tensor(), Tensor());
}
return std::make_tuple(std::move(attention), std::move(log_sumexp));
}
int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
@ -802,7 +786,7 @@ int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Te
return static_cast<int64_t>(backend);
}
Tensor flash_scaled_dot_product_attention(
std::tuple<Tensor, Tensor, Tensor> _flash_attention_forward(
const Tensor& query,
const Tensor& key,
const Tensor& value,
@ -810,11 +794,12 @@ Tensor flash_scaled_dot_product_attention(
const Tensor& cumulative_sequence_length_k,
const int64_t max_seqlen_batch_q,
const int64_t max_seqlen_batch_k,
bool return_softmax,
double dropout_p,
bool is_causal) {
#if defined(USE_FLASH_ATTENTION)
auto softmax_scale = std::pow(query.size(-1), -0.5);
std::vector<Tensor> output = fmha::mha_fwd(
return fmha::mha_fwd(
query,
key,
value,
@ -826,12 +811,11 @@ Tensor flash_scaled_dot_product_attention(
softmax_scale,
false,
is_causal,
false,
return_softmax,
c10::nullopt);
return output[0];
#endif
TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
return Tensor();
return std::make_tuple(Tensor(), Tensor(), Tensor());
}
std::tuple<at::Tensor, at::Tensor> _efficient_attention_forward(

View file

@ -10,6 +10,7 @@
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#include <iostream>
#ifdef USE_FLASH_ATTENTION
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
#endif
@ -73,14 +74,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& logsumexp,
const at::Tensor& out,
const at::Tensor& logsumexp,
bool causal) {
#if defined(USE_FLASH_ATTENTION)
if (!grad_out_.defined()) {
return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
}
// ndim
// ndim
TORCH_CHECK(query.dim() == grad_out_.dim());
TORCH_CHECK(query.dim() == key.dim());
TORCH_CHECK(query.dim() == value.dim());
@ -128,6 +129,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
// initialized
bool grad_kv_needs_init = causal && N > M;
at::Tensor grad_q, grad_k, grad_v;
int8_t gQKV_strideM_multiplier = 1;
if (!grad_kv_needs_init && query.size(1) == key.size(1) &&
query.size(3) == value.size(3) &&
query.storage().is_alias_of(key.storage()) &&
@ -141,10 +143,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
grad_q = chunk.select(2, 0);
grad_k = chunk.select(2, 1);
grad_v = chunk.select(2, 2);
gQKV_strideM_multiplier=3;
} else {
grad_q = at::empty_like(query);
grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key);
grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value);
grad_q = at::empty(query.sizes(), query.options());
grad_k = grad_kv_needs_init ? at::zeros(key.sizes(), key.options())
: at::empty(key.sizes(), key.options());
grad_v = grad_kv_needs_init ? at::zeros(value.sizes(), value.options())
: at::empty(value.sizes(), value.options());
}
auto launchKernel = [&](auto _k, int computeCapability) {
@ -198,7 +203,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2));
ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2));
ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2));
p.gQKV_strideM_multiplier = grad_q.is_contiguous() ? 1 : 3;
p.gQKV_strideM_multiplier = gQKV_strideM_multiplier;
TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1));
TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1));
TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1));
@ -257,5 +262,28 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_efficient_attention_backward_cuda(
const at::Tensor& grad_out_,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& out,
const at::Tensor& logsumexp,
bool causal){
if (!grad_out_.defined()) {
return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
}
auto grad_out = grad_out_.transpose(1, 2);
auto out_t = out.transpose(1, 2);
auto q_t = query.transpose(1, 2);
auto k_t = key.transpose(1, 2);
auto v_t = value.transpose(1, 2);
Tensor grad_q, grad_k, grad_v;
std::tie(grad_q, grad_k, grad_v) = at::_efficient_attention_backward(grad_out, q_t, k_t, v_t, out_t, logsumexp, causal);
return std::make_tuple(grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2));
}
} // namespace native
} // namespace at

View file

@ -26,6 +26,7 @@
*
******************************************************************************/
#include <tuple>
#ifdef USE_FLASH_ATTENTION
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
@ -115,7 +116,7 @@ void set_params_fprop(FMHA_fprop_params &params,
params.is_causal = is_causal;
}
std::vector<at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor>
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
@ -241,9 +242,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
run_fmha_fprop(launch_params, /*configure=*/false);
std::vector<at::Tensor> result = {o, softmax_lse};
if (return_softmax) {result.push_back(s);}
return result;
return std::make_tuple(o, softmax_lse, s);
}
} // namespace fmha
#endif

View file

@ -7,7 +7,7 @@
namespace fmha {
TORCH_API
std::vector<at::Tensor>
std::tuple<at::Tensor, at::Tensor, at::Tensor>
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i

View file

@ -91,6 +91,31 @@ inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) {
return true;
}
inline bool check_for_nested_inputs(sdp_params params, bool debug){
if (params.query.is_nested() || params.key.is_nested() || params.value.is_nested()) {
TORCH_CHECK(!debug, "We are not enabling nested Tensors for Flash Attention because of cuda memory errors.");
return false;
}
return true;
}
inline bool check_requires_grad(sdp_params params, bool debug) {
if (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad()) {
TORCH_CHECK(!debug, "Flash Attention does not currently support training.");
return false;
}
return true;
}
inline bool check_requires_grad_and_nested(sdp_params params, bool debug) {
// If we fail both checks then we return false
if (!check_for_nested_inputs(params, false) && !check_requires_grad(params,false)){
TORCH_CHECK(!debug, "Memory efficient attention currently doesn't support training with NT inputs.");
return false;
}
return true;
}
inline bool check_for_attn_mask(sdp_params params, bool debug) {
if (params.has_attn_mask) {
TORCH_CHECK(!debug, "Flash Attention does not support attention mask.");
@ -198,13 +223,15 @@ inline bool use_flash_attention(sdp_params params, bool debug) {
return false;
#endif
// Define gate functions that determine if a flash kernel can be ran
constexpr std::array<bool(*)(sdp_params, bool), 7> constraints {{
constexpr std::array<bool(*)(sdp_params, bool), 9> constraints {{
check_runtime_disabled_flash,
check_requires_grad,
check_tensor_shapes,
check_for_attn_weights,
check_for_attn_mask,
check_head_dim_size,
check_gpu_sm75_or_greater,
check_for_nested_inputs,
check_for_seq_len_1_nested_tensor}};
for (auto& constraint : constraints) {
if (!constraint(params, debug)) {
@ -232,14 +259,15 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
at::kHalf, at::kFloat, at::kBFloat16};
// Define gate functions that determine if a flash kernel can be ran
std::vector<std::function<bool(sdp_params, bool)>> constraints{
constexpr std::array<bool(*)(sdp_params, bool), 8> constraints{{
check_gpu_sm50_or_greater,
check_runtime_disabled_mem_efficient,
check_requires_grad_and_nested,
check_for_attn_weights,
check_tensor_shapes,
check_for_attn_mask,
check_for_seq_len_1_nested_tensor,
check_for_non_zero_dropout};
check_for_non_zero_dropout}};
for (auto& constraint : constraints) {
if (!constraint(params, debug)) {
return false;

View file

@ -0,0 +1,189 @@
import torch
import numpy as np
import random
import torch.utils.benchmark as benchmark
from torch.profiler import profile, record_function, ProfilerActivity
class CompositeMHA(torch.nn.Module):
def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
super().__init__()
self.in_proj_weight = in_proj_weight
self.in_proj_bias = in_proj_bias
self.out_proj = out_proj
self.num_heads = num_heads
def forward(self, query, key, value, mask):
if not (query is key and key is value):
raise NotImplementedError(
"query, key and value must be the same Tensor for now."
)
if mask is not None:
raise NotImplementedError("mask is currently not supported.")
query_projected = torch.nn.functional.linear(
query, self.in_proj_weight, self.in_proj_bias
)
batch_size = query_projected.size(0)
embed_dim = query_projected.size(2)
head_dim = embed_dim // (self.num_heads * 3)
query, key, value = query_projected.chunk(3, -1)
query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
attn, _ = torch.nn.functional._scaled_dot_product_attention(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
need_attn_weights=False,
is_causal=False,
)
attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
# Match return signature of nn.MHA
return self.out_proj(attn)
def build_composite_mha_from_nn_mha(pt):
assert pt._qkv_same_embed_dim
in_proj_weight = pt.in_proj_weight
assert in_proj_weight is not None
assert pt.batch_first
return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj)
def forw_back(model, input, upward):
output = model(*input)
output.backward(upward)
# Context manager not working in timer
def forw_back_fused(model, input, upward):
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
output = model(*input)
output.backward(upward)
def forw_back_eager(model, input, upward):
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
output = model(*input)
output.backward(upward)
def run_timing(
min_run_time, batch_size, embed_dimension, num_heads, max_sequence_len, dtype
):
dropout_p = 0.0
mask = None
pt = torch.nn.MultiheadAttention(
embed_dim=embed_dimension,
num_heads=num_heads,
batch_first=True,
dropout=dropout_p,
)
npt = pt.cuda().to(dtype)
cpt = build_composite_mha_from_nn_mha(npt)
x = torch.randn(
batch_size,
max_sequence_len,
embed_dimension,
dtype=dtype,
device="cuda",
requires_grad=True,
)
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
rand_fused_upward = cpt(x, x, x, mask).clone().detach()
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
rand_eager_upward = cpt(x, x, x, mask).clone().detach()
t0 = benchmark.Timer(
stmt="forw_back_fused(cpt, (x,x,x,mask), rand_fused_upward)",
globals={
"forw_back_fused": forw_back_fused,
"cpt": cpt,
"x": x,
"rand_fused_upward": rand_fused_upward,
"mask": mask,
},
label=f"Fused SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} "
f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}",
num_threads=torch.get_num_threads(),
)
t1 = benchmark.Timer(
stmt="forw_back_eager(cpt, (x,x,x,mask), rand_eager_upward)",
globals={
"forw_back_eager": forw_back_eager,
"cpt": cpt,
"x": x,
"rand_eager_upward": rand_eager_upward,
"mask": mask,
},
label=f"Eager SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} "
f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}",
num_threads=torch.get_num_threads(),
)
m0 = t0.blocked_autorange(min_run_time=min_run_time)
m1 = t1.blocked_autorange(min_run_time=min_run_time)
print(m0)
print(m1)
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
print("Profile for Fused".center(200, "-"))
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
with profile(
activities=activities, record_shapes=False, with_stack=True
) as prof:
with record_function("Fused SDP forward and backward"):
for _ in range(20):
forw_back(cpt, (x, x, x, mask), rand_fused_upward)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
print("Profile for eager".center(200, "-"))
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
with profile(
activities=activities, record_shapes=False, with_stack=True
) as prof:
with record_function("Fused SDP forward and backward"):
for _ in range(20):
forw_back(cpt, (x, x, x, mask), rand_eager_upward)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
def main():
seed = 123
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
min_run_time = 10
batch_size = 64
num_heads = 32
max_seq_len = 256
embed_dim = 1024
dtype = torch.bfloat16
print(
f"Running timing for batch_size={batch_size} max_sequence_len={max_seq_len} "
f"num_heads={num_heads} embed_dimension={embed_dim} dtype={dtype}"
)
run_timing(min_run_time, batch_size, embed_dim, num_heads, max_seq_len, dtype)
if __name__ == "__main__":
main()

View file

@ -317,6 +317,9 @@ ALLOW_LIST = [
("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)),
("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)),
("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)),
("aten::_flash_scaled_dot_product_attention", datetime.date(2022, 12, 15)),
("aten::_scaled_dot_product_attention_forward", datetime.date(2022, 12, 15)),
("aten::_efficient_attention_backward", datetime.date(2022, 12, 15)),
("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
]

View file

@ -401,6 +401,7 @@ class TestOperators(TestCase):
skip('nn.functional.max_unpool2d'), # fails everywhere except on windows
skip('nn.functional.max_unpool3d'), # fails everywhere except on mac
xfail("native_batch_norm"),
xfail('nn.functional._scaled_dot_product_attention', device_type='cuda'),
xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented
@ -555,6 +556,7 @@ class TestOperators(TestCase):
xfail('nn.functional.ctc_loss'), # Not Implemented
xfail('native_layer_norm', ''), # Expected a proper Tensor but got None for argument #1 'other'
xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides
skip('nn.functional._scaled_dot_product_attention', device_type='cuda'),
# AssertionError: Tensor-likes are not close!
# Mismatched elements: 1 / 15 (6.7%)
# Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
@ -649,7 +651,7 @@ class TestOperators(TestCase):
skip("nn.functional.feature_alpha_dropout", "with_train"), # calls random op
skip("nn.functional.fractional_max_pool2d"), # calls random op
skip("nn.functional.fractional_max_pool3d"), # calls random op
skip('nn.functional._scaled_dot_product_attention'), # randomness
xfail('nn.functional._scaled_dot_product_attention'), # randomness
# It looks like you're either (1) calling .item() on a Tensor or
# (2) attempting to use a Tensor in some data-dependent control flow or
# (3) encountering this error in PyTorch internals.
@ -1126,6 +1128,7 @@ class TestOperators(TestCase):
skip('nn.functional.rrelu'), # randomness
skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness
skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness
skip('nn.functional._scaled_dot_product_attention', device_type='cuda'),
skip('nn.functional.alpha_dropout'), # randomness
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
skip('to_sparse', ''), # non-dense output
@ -1249,6 +1252,7 @@ class TestOperators(TestCase):
xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for log_sigmoid_backward
xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss
xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward
skip('nn.functional._scaled_dot_product_attention', device_type='cuda'),
xfail('nn.functional.multi_margin_loss', ''), # NYI: forward AD with multi_margin_loss
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
@ -1369,7 +1373,7 @@ class TestOperators(TestCase):
xfail('nn.functional.dropout2d'), # calls random op
xfail('nn.functional.dropout3d'), # calls random op
xfail('nn.functional.dropout'), # calls random op
skip('nn.functional._scaled_dot_product_attention'), # randomness
xfail('nn.functional._scaled_dot_product_attention'), # randomness
xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition
xfail('nn.functional.alpha_dropout'), # calls random op
xfail('nn.functional.feature_alpha_dropout', 'with_train'), # calls random op

View file

@ -294,7 +294,6 @@ CHECK_STRIDES_SKIPS = {
aten._fft_c2r.default,
aten._fft_r2c.default,
aten._linalg_svd.default,
aten._scaled_dot_product_attention_forward.default,
aten.binary_cross_entropy.default,
aten.complex.default,
aten.copysign.Tensor,

View file

@ -1059,6 +1059,11 @@ class TestTransformers(NNTestCase):
if fused_kernel == "flash":
with sdp_kernel(enable_mem_efficient=False, enable_math=False):
# TODO Flash for the nested path is currently not working due to cuda memory issues
if type == "nested":
self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False))
return
actual = torch.nn.functional._scaled_dot_product_attention(
query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
elif fused_kernel == "mem_efficient":
@ -1097,28 +1102,73 @@ class TestTransformers(NNTestCase):
@unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system")
@parametrize("contiguous_inputs", [True, False])
def test_efficient_attention_gradcheck(self, contiguous_inputs: bool):
def test_sdp_math_gradcheck(self, contiguous_inputs: bool):
batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64
rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float16, requires_grad=True, packed=True)
batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True)
qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim))
query, key, value = qkv.chunk(3, dim=-1)
query = query.view(batch_size, -1, num_heads, head_dim)
key = key.view(batch_size, -1, num_heads, head_dim)
value = value.view(batch_size, -1, num_heads, head_dim)
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
if contiguous_inputs:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# Normally we would transpose the inputs but the fused kernels expect
# (batch, seq_len, num_heads, head_dim) bump the tolerance since we can only run kernel
# in fp32
assert gradcheck(lambda *args, **kwargs:
wrapper_set_seed(torch.ops.aten._efficient_attention_forward, *args, **kwargs),
(query, key, value, None, None, None, True, False), fast_mode=True, atol=8e-5, rtol=1e-3)
with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False):
assert gradcheck(lambda *args, **kwargs:
wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs),
(query, key, value, None, 0.0, False, False)
)
@unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system")
@parametrize("contiguous_inputs", [True, False])
def test_sdp_fused_grad_against_math(self, contiguous_inputs: bool):
batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True)
qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim))
qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_()
query, key, value = qkv.chunk(3, dim=-1)
query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1)
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
if contiguous_inputs:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
query_lp = query_lp.contiguous()
key_lp = key_lp.contiguous()
value_lp = value_lp.contiguous()
with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False):
out, atten = torch.nn.functional._scaled_dot_product_attention(query, key, value, None, 0.0, False, False)
with sdp_kernel(enable_math=False, enable_mem_efficient=True, enable_flash=False):
out_lp, atten_lp = torch.nn.functional._scaled_dot_product_attention(
query_lp, key_lp, value_lp, None, 0.0, False, False)
rand_upward = torch.rand_like(out)
rand_upward_lp = rand_upward.to(torch.float32)
out.backward(rand_upward)
out_lp.backward(rand_upward_lp)
# Cast up and compare
self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)
@parametrize("type", ["dense", "nested"])
def test_fused_sdp_choice(self, type: str):
@ -1144,7 +1194,7 @@ class TestTransformers(NNTestCase):
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
if SM80OrLater:
if SM80OrLater and not type == "nested":
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION
else:
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION

View file

@ -2613,9 +2613,13 @@
nested_strides: non_differentiable
# Transformers
- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor)
output_differentiability: [True, False]
query, key, value: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, result0, result1, is_causal)
- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor)
output_differentiability: [True, False]
query, key, value: _efficient_attention_backward(grad, query, key, value, result1, result0, causal)
query, key, value: _efficient_attention_backward(grad, query, key, value, result0, result1, causal)
# fft
- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor

View file

@ -12009,16 +12009,21 @@ op_db: List[OpInfo] = [
# This is only failing on Linux Bionic 3.10 Cuda 11.6
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
device_type='cuda', active_if=_get_torch_cuda_version() >= (11, 6)),
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples',
device_type='cuda', dtypes=(torch.float32,)),
# AssertionError: JIT Test does not execute any logic
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# Doesn't support autocasting
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensorNonErroring', 'test_fake_autocast', device_type='cpu'),
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'),
# Forward works for dtype=float64 which is the math path
DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
# No meta function
DecorateInfo(unittest.skip("Skipped!"), 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'),
DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'),
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'),
DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', device_type='cuda'),
DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'),),
),
UnaryUfuncInfo(