mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update sdp dispatch logic to enable fused backward (#89154)
# Summary Reorganizes how the sdp dispatch logic is down in order to enable backwards for fused kernels Pull Request resolved: https://github.com/pytorch/pytorch/pull/89154 Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent
cf9476554f
commit
1d9e1fca97
15 changed files with 500 additions and 138 deletions
|
|
@ -13252,19 +13252,40 @@
|
|||
CPU, NestedTensorCPU, Meta: _fused_sdp_choice_cpp
|
||||
CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
|
||||
|
||||
# Register the math kernel for cpu
|
||||
- func: _scaled_dot_product_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _scaled_dot_product_attention_forward_cuda
|
||||
CPU: _scaled_dot_product_attention_forward_math
|
||||
NestedTensorCUDA: _scaled_dot_product_attention_forward_nested
|
||||
NestedTensorCPU: _scaled_dot_product_attention_forward_math
|
||||
Meta: _scaled_dot_product_attention_forward_math
|
||||
|
||||
- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor)
|
||||
variants: function
|
||||
|
||||
- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool return_softmax=False, bool is_causal=False) -> (Tensor, Tensor, Tensor)
|
||||
dispatch:
|
||||
CUDA: _scaled_dot_product_flash_attention_cuda
|
||||
NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
|
||||
|
||||
- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor)
|
||||
dispatch:
|
||||
CUDA: _scaled_dot_product_efficient_attention_cuda
|
||||
NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda
|
||||
|
||||
- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor)
|
||||
dispatch:
|
||||
CUDA: _scaled_dot_product_efficient_attention_backward_cuda
|
||||
|
||||
# Returns ouput, softmax_logsumexp, softmax
|
||||
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, bool return_softmax, float dropout_p, bool is_causal) -> (Tensor, Tensor, Tensor)
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _flash_attention_forward
|
||||
|
||||
# Returns ouput, logsumexp if compute_logsumexp
|
||||
- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor)
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _efficient_attention_forward
|
||||
|
||||
- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor)
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _efficient_attention_backward
|
||||
|
||||
- func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
|
|
@ -13290,21 +13311,6 @@
|
|||
structured: True
|
||||
variants: function
|
||||
|
||||
- func: _flash_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: flash_scaled_dot_product_attention
|
||||
|
||||
- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor)
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _efficient_attention_forward
|
||||
|
||||
- func: _efficient_attention_backward(Tensor grad, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor out, bool is_causal=False) -> (Tensor, Tensor, Tensor)
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _efficient_attention_backward
|
||||
|
||||
- func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? incr_value=None) -> (Tensor, Tensor, Tensor)
|
||||
variants: function
|
||||
dispatch:
|
||||
|
|
|
|||
|
|
@ -214,26 +214,6 @@ Tensor NestedTensor_to_padded_tensor_cuda(
|
|||
return NestedTensor_to_padded_tensor_generic(t, padding, output_size);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_nested(
|
||||
const Tensor& query_, const Tensor& key, const Tensor& value,
|
||||
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) {
|
||||
|
||||
// Determine which efficient kernel to use
|
||||
sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal};
|
||||
auto backend = select_sdp_backend(kernel_params);
|
||||
switch(backend){
|
||||
case sdp::SDPBackend::flash_attention:
|
||||
// TODO: enable flash attention kernel
|
||||
return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal);
|
||||
case sdp::SDPBackend::efficient_attention:
|
||||
return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal);
|
||||
case sdp::SDPBackend::math:
|
||||
return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported backend for scaled_dot_product_attention");
|
||||
return std::make_tuple(Tensor(), Tensor());
|
||||
}
|
||||
}
|
||||
namespace{
|
||||
|
||||
/**
|
||||
|
|
@ -340,19 +320,80 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) {
|
|||
}
|
||||
|
||||
} // namespace
|
||||
std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_nestedtensor_cuda(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
double dropout_p,
|
||||
bool need_atten_weights,
|
||||
bool return_softmax,
|
||||
bool is_causal) {
|
||||
TORCH_CHECK(false, "There are currently cuda memory errors being returned from this path.")
|
||||
// Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
|
||||
// Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
|
||||
// Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
|
||||
const int64_t num_heads = query.size(1);
|
||||
const int64_t head_dim = query.size(3);
|
||||
|
||||
// Query -> Query (Batch x {Q_seq_len} x Num_heads x Dim_per_head)
|
||||
// Key -> Key (Batch x {KV_seq_len} x Num_heads x Dim_per_head)
|
||||
// Value -> Value (Batch x {KV_seq_len} x Num_heads x Dim_per_head)
|
||||
Tensor q_t = query.transpose(1, 2).contiguous();
|
||||
Tensor k_t = key.transpose(1, 2).contiguous();
|
||||
Tensor v_t = value.transpose(1, 2).contiguous();
|
||||
|
||||
// K and V have to have the same Nnz, should probably torch_check
|
||||
// assume in order to not iterate over v
|
||||
|
||||
auto cumulative_and_max_q = cumulative_and_max_seq_len(q_t);
|
||||
auto cumulative_and_max_k = cumulative_and_max_seq_len(k_t);
|
||||
|
||||
Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q);
|
||||
Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k);
|
||||
|
||||
const int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q);
|
||||
const int64_t max_seqlen_batch_k = std::get<1>(cumulative_and_max_k);
|
||||
|
||||
const int64_t Nnz_q = cumulative_sequence_length_q[-1].item<int64_t>();
|
||||
const int64_t Nnz_kv = cumulative_sequence_length_k[-1].item<int64_t>();
|
||||
|
||||
auto query_buffer_reshaped =
|
||||
get_buffer(q_t).view({Nnz_q, num_heads, head_dim});
|
||||
auto key_buffer_reshaped =
|
||||
get_buffer(k_t).view({Nnz_kv, num_heads, head_dim});
|
||||
auto value_buffer_reshaped =
|
||||
get_buffer(v_t).view({Nnz_kv, num_heads, head_dim});
|
||||
|
||||
auto attention_and_lse_and_softmax =
|
||||
at::_flash_attention_forward(
|
||||
query_buffer_reshaped,
|
||||
key_buffer_reshaped,
|
||||
value_buffer_reshaped,
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_k,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_k,
|
||||
return_softmax,
|
||||
dropout_p,
|
||||
is_causal);
|
||||
// Reshape output to convert nnz to batch_size and seq_len
|
||||
Tensor attention = std::get<0>(attention_and_lse_and_softmax);
|
||||
attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()).transpose(1,2);
|
||||
return std::tie(attention, std::get<1>(attention_and_lse_and_softmax), std::get<2>(attention_and_lse_and_softmax));
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_nestedtensor_cuda(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
bool compute_log_sumexp,
|
||||
bool is_causal) {
|
||||
// Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
|
||||
// Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
|
||||
// Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
|
||||
const int64_t num_heads = query.size(1);
|
||||
const int64_t head_dim = query.size(3);
|
||||
|
||||
Tensor q_t = query.transpose(1, 2);
|
||||
Tensor k_t = key.transpose(1, 2);
|
||||
Tensor v_t = value.transpose(1, 2);
|
||||
|
|
@ -432,7 +473,7 @@ std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
|
|||
{Nnz_kv, num_heads, head_dim},
|
||||
{nnz_v_stride, head_v_stride, head_dim_stride},
|
||||
value_impl->get_storage_offsets()[0]);
|
||||
std::tuple<Tensor, Tensor> attention_and_weights =
|
||||
std::tuple<Tensor, Tensor> attention_and_logsumexp=
|
||||
at::_efficient_attention_forward(
|
||||
query_buffer_reshaped.unsqueeze(0),
|
||||
key_buffer_reshaped.unsqueeze(0),
|
||||
|
|
@ -440,14 +481,14 @@ std::tuple<Tensor, Tensor> mem_efficient_helper_nested_unpacked(
|
|||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_k,
|
||||
max_seqlen_batch_q,
|
||||
false,
|
||||
false);
|
||||
compute_log_sumexp,
|
||||
is_causal);
|
||||
// Reshape output to convert nnz to batch_size and seq_len
|
||||
Tensor attention = std::get<0>(attention_and_weights);
|
||||
Tensor attention = std::get<0>(attention_and_logsumexp);
|
||||
attention =
|
||||
wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone())
|
||||
.transpose(1, 2);
|
||||
return std::tie(attention, std::get<1>(attention_and_weights));
|
||||
return std::tie(attention, std::get<1>(attention_and_logsumexp));
|
||||
}
|
||||
|
||||
Tensor flash_attention_helper(
|
||||
|
|
@ -492,7 +533,7 @@ Tensor flash_attention_helper(
|
|||
// If we are passing in query, key, value all the same tensors then we have
|
||||
// packed them into one tensor and need to slice for flash attention
|
||||
Tensor attention =
|
||||
at::_flash_scaled_dot_product_attention(
|
||||
std::get<0>(at::_flash_attention_forward(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
|
|
@ -500,8 +541,9 @@ Tensor flash_attention_helper(
|
|||
cumulative_sequence_length_q,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_q,
|
||||
false /*return_softmax*/,
|
||||
dropout_p,
|
||||
is_causal);
|
||||
is_causal));
|
||||
// Output of flash_attention is a regular tensor lets wrap it back up to
|
||||
// form a nested tensor
|
||||
|
||||
|
|
|
|||
|
|
@ -678,20 +678,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> native_decoder_only_multi_head_attent
|
|||
// L: Target sequence length
|
||||
// E: Embedding dimension
|
||||
std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
|
||||
const Tensor& query_, const Tensor& key, const Tensor& value,
|
||||
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) {
|
||||
if (query_.requires_grad() || key.requires_grad() || value.requires_grad()){
|
||||
return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
|
||||
}
|
||||
return at::_scaled_dot_product_attention_forward(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
|
||||
}
|
||||
|
||||
int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
|
||||
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){
|
||||
return static_cast<int64_t>(sdp::SDPBackend::math);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_math(
|
||||
const Tensor& query_,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
|
|
@ -699,14 +685,49 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_math(
|
|||
double dropout_p,
|
||||
bool need_attn_weights,
|
||||
bool is_causal) {
|
||||
return at::_scaled_dot_product_attention_math(
|
||||
query_,
|
||||
key,
|
||||
value,
|
||||
attn_mask_,
|
||||
dropout_p,
|
||||
need_attn_weights,
|
||||
is_causal);
|
||||
// TODO: The second return is the attention weights if the math kernel is
|
||||
// used. The fused kernels do not return this Tensor so for the fused kernels
|
||||
// The second return SHOULD always be an empty Tensor, unless need_attn_weights
|
||||
// is true (in which case the fused kernels would not be called). This blows up
|
||||
// op_info tests.
|
||||
int64_t choice_int = at::_fused_sdp_choice(
|
||||
query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
|
||||
sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
|
||||
switch (backend) {
|
||||
case sdp::SDPBackend::flash_attention: {
|
||||
auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
|
||||
query_, key, value, dropout_p, need_attn_weights, is_causal);
|
||||
return std::make_tuple(
|
||||
std::move(std::get<0>(out_lse_softmax)),
|
||||
std::move(std::get<2>(out_lse_softmax)));
|
||||
}
|
||||
case sdp::SDPBackend::efficient_attention: {
|
||||
bool compute_logsumexp =
|
||||
(query_.requires_grad() || key.requires_grad() ||
|
||||
value.requires_grad());
|
||||
return at::_scaled_dot_product_efficient_attention(
|
||||
query_, key, value, compute_logsumexp, is_causal);
|
||||
}
|
||||
case sdp::SDPBackend::math:
|
||||
return at::_scaled_dot_product_attention_math(
|
||||
query_,
|
||||
key,
|
||||
value,
|
||||
attn_mask_,
|
||||
dropout_p,
|
||||
need_attn_weights,
|
||||
is_causal);
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"No viable backend for scaled_dot_product_attention was found.");
|
||||
return std::make_tuple(Tensor(), Tensor());
|
||||
}
|
||||
}
|
||||
|
||||
int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value,
|
||||
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){
|
||||
return static_cast<int64_t>(sdp::SDPBackend::math);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
|
||||
|
|
|
|||
|
|
@ -678,12 +678,12 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
|
|||
return std::make_tuple(std::move(proj), std::move(qkt));
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> flash_attention_helper_dense_unpacked(
|
||||
std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
double dropout_p,
|
||||
bool need_atten_weights,
|
||||
bool return_softmax,
|
||||
bool is_causal) {
|
||||
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
|
||||
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
|
||||
|
|
@ -726,8 +726,9 @@ std::tuple<Tensor, Tensor> flash_attention_helper_dense_unpacked(
|
|||
Tensor key_reshaped = k_t.reshape({Nnz_kv, num_heads, head_dim});
|
||||
Tensor value_reshaped = v_t.reshape({Nnz_kv, num_heads, head_dim});
|
||||
|
||||
Tensor attention =
|
||||
at::_flash_scaled_dot_product_attention(
|
||||
Tensor attention, log_sumexp, softmax;
|
||||
std::tie(attention, log_sumexp, softmax) =
|
||||
at::_flash_attention_forward(
|
||||
query_reshaped,
|
||||
key_reshaped,
|
||||
value_reshaped,
|
||||
|
|
@ -735,15 +736,17 @@ std::tuple<Tensor, Tensor> flash_attention_helper_dense_unpacked(
|
|||
cumulative_sequence_length_k,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_k,
|
||||
return_softmax,
|
||||
dropout_p,
|
||||
is_causal);
|
||||
// Reshape output to convert nnz to batch_size and seq_len
|
||||
attention =
|
||||
attention.view({batch_size, max_seqlen_batch_q, num_heads, head_dim}).transpose(1,2);
|
||||
|
||||
return std::tuple<Tensor, Tensor>(attention, Tensor());
|
||||
return std::make_tuple(attention, log_sumexp, softmax);
|
||||
}
|
||||
std::tuple<Tensor, Tensor> mem_eff_helper(
|
||||
|
||||
std::tuple<Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
|
|
@ -767,26 +770,7 @@ std::tuple<Tensor, Tensor> mem_eff_helper(
|
|||
compute_log_sumexp,
|
||||
is_causal);
|
||||
attention = attention.transpose(1,2);
|
||||
return std::make_tuple(std::move(attention), Tensor());
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> _scaled_dot_product_attention_forward_cuda(
|
||||
const Tensor& query_, const Tensor& key, const Tensor& value,
|
||||
const c10::optional<Tensor>& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) {
|
||||
// Determine which efficient kernel to use
|
||||
sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal};
|
||||
auto backend = select_sdp_backend(kernel_params);
|
||||
switch(backend){
|
||||
case sdp::SDPBackend::flash_attention:
|
||||
return flash_attention_helper_dense_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal);
|
||||
case sdp::SDPBackend::efficient_attention:
|
||||
return mem_eff_helper(query_, key , value, need_attn_weights, is_causal);
|
||||
case sdp::SDPBackend::math:
|
||||
return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal);
|
||||
default:
|
||||
TORCH_CHECK(false, "No viable backend for scaled_dot_product_attention was found.");
|
||||
return std::make_tuple(Tensor(), Tensor());
|
||||
}
|
||||
return std::make_tuple(std::move(attention), std::move(log_sumexp));
|
||||
}
|
||||
|
||||
int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value,
|
||||
|
|
@ -802,7 +786,7 @@ int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Te
|
|||
return static_cast<int64_t>(backend);
|
||||
}
|
||||
|
||||
Tensor flash_scaled_dot_product_attention(
|
||||
std::tuple<Tensor, Tensor, Tensor> _flash_attention_forward(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
|
|
@ -810,11 +794,12 @@ Tensor flash_scaled_dot_product_attention(
|
|||
const Tensor& cumulative_sequence_length_k,
|
||||
const int64_t max_seqlen_batch_q,
|
||||
const int64_t max_seqlen_batch_k,
|
||||
bool return_softmax,
|
||||
double dropout_p,
|
||||
bool is_causal) {
|
||||
#if defined(USE_FLASH_ATTENTION)
|
||||
auto softmax_scale = std::pow(query.size(-1), -0.5);
|
||||
std::vector<Tensor> output = fmha::mha_fwd(
|
||||
return fmha::mha_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
|
|
@ -826,12 +811,11 @@ Tensor flash_scaled_dot_product_attention(
|
|||
softmax_scale,
|
||||
false,
|
||||
is_causal,
|
||||
false,
|
||||
return_softmax,
|
||||
c10::nullopt);
|
||||
return output[0];
|
||||
#endif
|
||||
TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
|
||||
return Tensor();
|
||||
return std::make_tuple(Tensor(), Tensor(), Tensor());
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> _efficient_attention_forward(
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include <ATen/native/transformers/attention.h>
|
||||
#include <ATen/native/transformers/cuda/sdp_utils.h>
|
||||
|
||||
#include <iostream>
|
||||
#ifdef USE_FLASH_ATTENTION
|
||||
#include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
|
||||
#endif
|
||||
|
|
@ -73,14 +74,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
|
|||
const at::Tensor& query,
|
||||
const at::Tensor& key,
|
||||
const at::Tensor& value,
|
||||
const at::Tensor& logsumexp,
|
||||
const at::Tensor& out,
|
||||
const at::Tensor& logsumexp,
|
||||
bool causal) {
|
||||
#if defined(USE_FLASH_ATTENTION)
|
||||
if (!grad_out_.defined()) {
|
||||
return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
|
||||
}
|
||||
// ndim
|
||||
// ndim
|
||||
TORCH_CHECK(query.dim() == grad_out_.dim());
|
||||
TORCH_CHECK(query.dim() == key.dim());
|
||||
TORCH_CHECK(query.dim() == value.dim());
|
||||
|
|
@ -128,6 +129,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
|
|||
// initialized
|
||||
bool grad_kv_needs_init = causal && N > M;
|
||||
at::Tensor grad_q, grad_k, grad_v;
|
||||
int8_t gQKV_strideM_multiplier = 1;
|
||||
if (!grad_kv_needs_init && query.size(1) == key.size(1) &&
|
||||
query.size(3) == value.size(3) &&
|
||||
query.storage().is_alias_of(key.storage()) &&
|
||||
|
|
@ -141,10 +143,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
|
|||
grad_q = chunk.select(2, 0);
|
||||
grad_k = chunk.select(2, 1);
|
||||
grad_v = chunk.select(2, 2);
|
||||
gQKV_strideM_multiplier=3;
|
||||
} else {
|
||||
grad_q = at::empty_like(query);
|
||||
grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key);
|
||||
grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value);
|
||||
grad_q = at::empty(query.sizes(), query.options());
|
||||
grad_k = grad_kv_needs_init ? at::zeros(key.sizes(), key.options())
|
||||
: at::empty(key.sizes(), key.options());
|
||||
grad_v = grad_kv_needs_init ? at::zeros(value.sizes(), value.options())
|
||||
: at::empty(value.sizes(), value.options());
|
||||
}
|
||||
|
||||
auto launchKernel = [&](auto _k, int computeCapability) {
|
||||
|
|
@ -198,7 +203,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
|
|||
ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2));
|
||||
ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2));
|
||||
ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2));
|
||||
p.gQKV_strideM_multiplier = grad_q.is_contiguous() ? 1 : 3;
|
||||
p.gQKV_strideM_multiplier = gQKV_strideM_multiplier;
|
||||
TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1));
|
||||
TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1));
|
||||
TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1));
|
||||
|
|
@ -257,5 +262,28 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _efficient_attention_backward(
|
|||
return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
|
||||
}
|
||||
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_efficient_attention_backward_cuda(
|
||||
const at::Tensor& grad_out_,
|
||||
const at::Tensor& query,
|
||||
const at::Tensor& key,
|
||||
const at::Tensor& value,
|
||||
const at::Tensor& out,
|
||||
const at::Tensor& logsumexp,
|
||||
bool causal){
|
||||
if (!grad_out_.defined()) {
|
||||
return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
|
||||
}
|
||||
auto grad_out = grad_out_.transpose(1, 2);
|
||||
auto out_t = out.transpose(1, 2);
|
||||
auto q_t = query.transpose(1, 2);
|
||||
auto k_t = key.transpose(1, 2);
|
||||
auto v_t = value.transpose(1, 2);
|
||||
|
||||
Tensor grad_q, grad_k, grad_v;
|
||||
std::tie(grad_q, grad_k, grad_v) = at::_efficient_attention_backward(grad_out, q_t, k_t, v_t, out_t, logsumexp, causal);
|
||||
return std::make_tuple(grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2));
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@
|
|||
*
|
||||
******************************************************************************/
|
||||
|
||||
#include <tuple>
|
||||
#ifdef USE_FLASH_ATTENTION
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
|
@ -115,7 +116,7 @@ void set_params_fprop(FMHA_fprop_params ¶ms,
|
|||
params.is_causal = is_causal;
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
|
|
@ -241,9 +242,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
|||
|
||||
run_fmha_fprop(launch_params, /*configure=*/false);
|
||||
|
||||
std::vector<at::Tensor> result = {o, softmax_lse};
|
||||
if (return_softmax) {result.push_back(s);}
|
||||
return result;
|
||||
return std::make_tuple(o, softmax_lse, s);
|
||||
}
|
||||
} // namespace fmha
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
namespace fmha {
|
||||
|
||||
TORCH_API
|
||||
std::vector<at::Tensor>
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
|
|
|
|||
|
|
@ -91,6 +91,31 @@ inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) {
|
|||
return true;
|
||||
}
|
||||
|
||||
inline bool check_for_nested_inputs(sdp_params params, bool debug){
|
||||
if (params.query.is_nested() || params.key.is_nested() || params.value.is_nested()) {
|
||||
TORCH_CHECK(!debug, "We are not enabling nested Tensors for Flash Attention because of cuda memory errors.");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool check_requires_grad(sdp_params params, bool debug) {
|
||||
if (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad()) {
|
||||
TORCH_CHECK(!debug, "Flash Attention does not currently support training.");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool check_requires_grad_and_nested(sdp_params params, bool debug) {
|
||||
// If we fail both checks then we return false
|
||||
if (!check_for_nested_inputs(params, false) && !check_requires_grad(params,false)){
|
||||
TORCH_CHECK(!debug, "Memory efficient attention currently doesn't support training with NT inputs.");
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool check_for_attn_mask(sdp_params params, bool debug) {
|
||||
if (params.has_attn_mask) {
|
||||
TORCH_CHECK(!debug, "Flash Attention does not support attention mask.");
|
||||
|
|
@ -198,13 +223,15 @@ inline bool use_flash_attention(sdp_params params, bool debug) {
|
|||
return false;
|
||||
#endif
|
||||
// Define gate functions that determine if a flash kernel can be ran
|
||||
constexpr std::array<bool(*)(sdp_params, bool), 7> constraints {{
|
||||
constexpr std::array<bool(*)(sdp_params, bool), 9> constraints {{
|
||||
check_runtime_disabled_flash,
|
||||
check_requires_grad,
|
||||
check_tensor_shapes,
|
||||
check_for_attn_weights,
|
||||
check_for_attn_mask,
|
||||
check_head_dim_size,
|
||||
check_gpu_sm75_or_greater,
|
||||
check_for_nested_inputs,
|
||||
check_for_seq_len_1_nested_tensor}};
|
||||
for (auto& constraint : constraints) {
|
||||
if (!constraint(params, debug)) {
|
||||
|
|
@ -232,14 +259,15 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
|
|||
at::kHalf, at::kFloat, at::kBFloat16};
|
||||
|
||||
// Define gate functions that determine if a flash kernel can be ran
|
||||
std::vector<std::function<bool(sdp_params, bool)>> constraints{
|
||||
constexpr std::array<bool(*)(sdp_params, bool), 8> constraints{{
|
||||
check_gpu_sm50_or_greater,
|
||||
check_runtime_disabled_mem_efficient,
|
||||
check_requires_grad_and_nested,
|
||||
check_for_attn_weights,
|
||||
check_tensor_shapes,
|
||||
check_for_attn_mask,
|
||||
check_for_seq_len_1_nested_tensor,
|
||||
check_for_non_zero_dropout};
|
||||
check_for_non_zero_dropout}};
|
||||
for (auto& constraint : constraints) {
|
||||
if (!constraint(params, debug)) {
|
||||
return false;
|
||||
|
|
|
|||
189
benchmarks/transformer/sdp_backwards.py
Normal file
189
benchmarks/transformer/sdp_backwards.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import random
|
||||
import torch.utils.benchmark as benchmark
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
|
||||
class CompositeMHA(torch.nn.Module):
|
||||
def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
|
||||
super().__init__()
|
||||
self.in_proj_weight = in_proj_weight
|
||||
self.in_proj_bias = in_proj_bias
|
||||
self.out_proj = out_proj
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, query, key, value, mask):
|
||||
if not (query is key and key is value):
|
||||
raise NotImplementedError(
|
||||
"query, key and value must be the same Tensor for now."
|
||||
)
|
||||
if mask is not None:
|
||||
raise NotImplementedError("mask is currently not supported.")
|
||||
|
||||
query_projected = torch.nn.functional.linear(
|
||||
query, self.in_proj_weight, self.in_proj_bias
|
||||
)
|
||||
|
||||
batch_size = query_projected.size(0)
|
||||
embed_dim = query_projected.size(2)
|
||||
head_dim = embed_dim // (self.num_heads * 3)
|
||||
|
||||
query, key, value = query_projected.chunk(3, -1)
|
||||
|
||||
query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
attn, _ = torch.nn.functional._scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
need_attn_weights=False,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
|
||||
# Match return signature of nn.MHA
|
||||
return self.out_proj(attn)
|
||||
|
||||
|
||||
def build_composite_mha_from_nn_mha(pt):
|
||||
assert pt._qkv_same_embed_dim
|
||||
in_proj_weight = pt.in_proj_weight
|
||||
assert in_proj_weight is not None
|
||||
assert pt.batch_first
|
||||
return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj)
|
||||
|
||||
|
||||
def forw_back(model, input, upward):
|
||||
output = model(*input)
|
||||
output.backward(upward)
|
||||
|
||||
|
||||
# Context manger not working in timer
|
||||
|
||||
|
||||
def forw_back_fused(model, input, upward):
|
||||
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
|
||||
output = model(*input)
|
||||
output.backward(upward)
|
||||
|
||||
|
||||
def forw_back_eager(model, input, upward):
|
||||
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
|
||||
output = model(*input)
|
||||
output.backward(upward)
|
||||
|
||||
|
||||
def run_timing(
|
||||
min_run_time, batch_size, embed_dimension, num_heads, max_sequence_len, dtype
|
||||
):
|
||||
dropout_p = 0.0
|
||||
mask = None
|
||||
|
||||
pt = torch.nn.MultiheadAttention(
|
||||
embed_dim=embed_dimension,
|
||||
num_heads=num_heads,
|
||||
batch_first=True,
|
||||
dropout=dropout_p,
|
||||
)
|
||||
npt = pt.cuda().to(dtype)
|
||||
cpt = build_composite_mha_from_nn_mha(npt)
|
||||
x = torch.randn(
|
||||
batch_size,
|
||||
max_sequence_len,
|
||||
embed_dimension,
|
||||
dtype=dtype,
|
||||
device="cuda",
|
||||
requires_grad=True,
|
||||
)
|
||||
|
||||
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
|
||||
rand_fused_upward = cpt(x, x, x, mask).clone().detach()
|
||||
|
||||
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
|
||||
rand_eager_upward = cpt(x, x, x, mask).clone().detach()
|
||||
|
||||
t0 = benchmark.Timer(
|
||||
stmt="forw_back_fused(cpt, (x,x,x,mask), rand_fused_upward)",
|
||||
globals={
|
||||
"forw_back_fused": forw_back_fused,
|
||||
"cpt": cpt,
|
||||
"x": x,
|
||||
"rand_fused_upward": rand_fused_upward,
|
||||
"mask": mask,
|
||||
},
|
||||
label=f"Fused SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} "
|
||||
f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}",
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
|
||||
t1 = benchmark.Timer(
|
||||
stmt="forw_back_eager(cpt, (x,x,x,mask), rand_eager_upward)",
|
||||
globals={
|
||||
"forw_back_eager": forw_back_eager,
|
||||
"cpt": cpt,
|
||||
"x": x,
|
||||
"rand_eager_upward": rand_eager_upward,
|
||||
"mask": mask,
|
||||
},
|
||||
label=f"Eager SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} "
|
||||
f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}",
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
|
||||
m0 = t0.blocked_autorange(min_run_time=min_run_time)
|
||||
m1 = t1.blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
print(m0)
|
||||
print(m1)
|
||||
|
||||
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
|
||||
|
||||
print("Profile for Fused".center(200, "-"))
|
||||
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, with_stack=True
|
||||
) as prof:
|
||||
with record_function("Fused SDP forward and backward"):
|
||||
for _ in range(20):
|
||||
forw_back(cpt, (x, x, x, mask), rand_fused_upward)
|
||||
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
|
||||
|
||||
print("Profile for eager".center(200, "-"))
|
||||
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False):
|
||||
with profile(
|
||||
activities=activities, record_shapes=False, with_stack=True
|
||||
) as prof:
|
||||
with record_function("Fused SDP forward and backward"):
|
||||
for _ in range(20):
|
||||
forw_back(cpt, (x, x, x, mask), rand_eager_upward)
|
||||
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
|
||||
|
||||
|
||||
def main():
|
||||
seed = 123
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
min_run_time = 10
|
||||
batch_size = 64
|
||||
num_heads = 32
|
||||
max_seq_len = 256
|
||||
embed_dim = 1024
|
||||
dtype = torch.bfloat16
|
||||
|
||||
print(
|
||||
f"Running timing for batch_size={batch_size} max_sequence_len={max_seq_len} "
|
||||
f"num_heads={num_heads} embed_dimension={embed_dim} dtype={dtype}"
|
||||
)
|
||||
run_timing(min_run_time, batch_size, embed_dim, num_heads, max_seq_len, dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -317,6 +317,9 @@ ALLOW_LIST = [
|
|||
("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)),
|
||||
("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)),
|
||||
("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)),
|
||||
("aten::_flash_scaled_dot_product_attention", datetime.date(2022, 12, 15)),
|
||||
("aten::_scaled_dot_product_attention_forward", datetime.date(2022, 12, 15)),
|
||||
("aten::_efficient_attention_backward", datetime.date(2022, 12, 15)),
|
||||
("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -401,6 +401,7 @@ class TestOperators(TestCase):
|
|||
skip('nn.functional.max_unpool2d'), # fails everywhere except on windows
|
||||
skip('nn.functional.max_unpool3d'), # fails everywhere except on mac
|
||||
xfail("native_batch_norm"),
|
||||
xfail('nn.functional._scaled_dot_product_attention', device_type='cuda'),
|
||||
|
||||
xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented
|
||||
|
||||
|
|
@ -555,6 +556,7 @@ class TestOperators(TestCase):
|
|||
xfail('nn.functional.ctc_loss'), # Not Implemented
|
||||
xfail('native_layer_norm', ''), # Expected a proper Tensor but got None for argument #1 'other'
|
||||
xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides
|
||||
skip('nn.functional._scaled_dot_product_attention', device_type='cuda'),
|
||||
# AssertionError: Tensor-likes are not close!
|
||||
# Mismatched elements: 1 / 15 (6.7%)
|
||||
# Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
|
||||
|
|
@ -649,7 +651,7 @@ class TestOperators(TestCase):
|
|||
skip("nn.functional.feature_alpha_dropout", "with_train"), # calls random op
|
||||
skip("nn.functional.fractional_max_pool2d"), # calls random op
|
||||
skip("nn.functional.fractional_max_pool3d"), # calls random op
|
||||
skip('nn.functional._scaled_dot_product_attention'), # randomness
|
||||
xfail('nn.functional._scaled_dot_product_attention'), # randomness
|
||||
# It looks like you're either (1) calling .item() on a Tensor or
|
||||
# (2) attempting to use a Tensor in some data-dependent control flow or
|
||||
# (3) encountering this error in PyTorch internals.
|
||||
|
|
@ -1126,6 +1128,7 @@ class TestOperators(TestCase):
|
|||
skip('nn.functional.rrelu'), # randomness
|
||||
skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness
|
||||
skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness
|
||||
skip('nn.functional._scaled_dot_product_attention', device_type='cuda'),
|
||||
skip('nn.functional.alpha_dropout'), # randomness
|
||||
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
|
||||
skip('to_sparse', ''), # non-dense output
|
||||
|
|
@ -1249,6 +1252,7 @@ class TestOperators(TestCase):
|
|||
xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for log_sigmoid_backward
|
||||
xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss
|
||||
xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward
|
||||
skip('nn.functional._scaled_dot_product_attention', device_type='cuda'),
|
||||
xfail('nn.functional.multi_margin_loss', ''), # NYI: forward AD with multi_margin_loss
|
||||
skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why
|
||||
xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides
|
||||
|
|
@ -1369,7 +1373,7 @@ class TestOperators(TestCase):
|
|||
xfail('nn.functional.dropout2d'), # calls random op
|
||||
xfail('nn.functional.dropout3d'), # calls random op
|
||||
xfail('nn.functional.dropout'), # calls random op
|
||||
skip('nn.functional._scaled_dot_product_attention'), # randomness
|
||||
xfail('nn.functional._scaled_dot_product_attention'), # randomness
|
||||
xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition
|
||||
xfail('nn.functional.alpha_dropout'), # calls randomn op
|
||||
xfail('nn.functional.feature_alpha_dropout', 'with_train'), # calls random op
|
||||
|
|
|
|||
|
|
@ -294,7 +294,6 @@ CHECK_STRIDES_SKIPS = {
|
|||
aten._fft_c2r.default,
|
||||
aten._fft_r2c.default,
|
||||
aten._linalg_svd.default,
|
||||
aten._scaled_dot_product_attention_forward.default,
|
||||
aten.binary_cross_entropy.default,
|
||||
aten.complex.default,
|
||||
aten.copysign.Tensor,
|
||||
|
|
|
|||
|
|
@ -1059,6 +1059,11 @@ class TestTransformers(NNTestCase):
|
|||
|
||||
if fused_kernel == "flash":
|
||||
with sdp_kernel(enable_mem_efficient=False, enable_math=False):
|
||||
# TODO Flash for the nested path is currently not working due to cuda memory issues
|
||||
if type == "nested":
|
||||
self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
|
||||
query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False))
|
||||
return
|
||||
actual = torch.nn.functional._scaled_dot_product_attention(
|
||||
query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)
|
||||
elif fused_kernel == "mem_efficient":
|
||||
|
|
@ -1097,28 +1102,73 @@ class TestTransformers(NNTestCase):
|
|||
|
||||
@unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system")
|
||||
@parametrize("contiguous_inputs", [True, False])
|
||||
def test_efficient_attention_gradcheck(self, contiguous_inputs: bool):
|
||||
def test_sdp_math_gradcheck(self, contiguous_inputs: bool):
|
||||
|
||||
batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64
|
||||
rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float16, requires_grad=True, packed=True)
|
||||
batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
|
||||
rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True)
|
||||
|
||||
qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim))
|
||||
query, key, value = qkv.chunk(3, dim=-1)
|
||||
query = query.view(batch_size, -1, num_heads, head_dim)
|
||||
key = key.view(batch_size, -1, num_heads, head_dim)
|
||||
value = value.view(batch_size, -1, num_heads, head_dim)
|
||||
|
||||
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
if contiguous_inputs:
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
|
||||
# Normally we would transpose the inputs but the fused kernels expect
|
||||
# (batch, seq_len, num_heads, head_dim) bump the tolerance since we can only run kernel
|
||||
# in fp32
|
||||
assert gradcheck(lambda *args, **kwargs:
|
||||
wrapper_set_seed(torch.ops.aten._efficient_attention_forward, *args, **kwargs),
|
||||
(query, key, value, None, None, None, True, False), fast_mode=True, atol=8e-5, rtol=1e-3)
|
||||
with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False):
|
||||
assert gradcheck(lambda *args, **kwargs:
|
||||
wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs),
|
||||
(query, key, value, None, 0.0, False, False)
|
||||
)
|
||||
|
||||
@unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system")
|
||||
@parametrize("contiguous_inputs", [True, False])
|
||||
def test_sdp_fused_grad_against_math(self, contiguous_inputs: bool):
|
||||
batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16
|
||||
rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True)
|
||||
|
||||
qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim))
|
||||
qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_()
|
||||
|
||||
query, key, value = qkv.chunk(3, dim=-1)
|
||||
query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1)
|
||||
|
||||
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
if contiguous_inputs:
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
|
||||
query_lp = query_lp.contiguous()
|
||||
key_lp = key_lp.contiguous()
|
||||
value_lp = value_lp.contiguous()
|
||||
|
||||
with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False):
|
||||
out, atten = torch.nn.functional._scaled_dot_product_attention(query, key, value, None, 0.0, False, False)
|
||||
|
||||
with sdp_kernel(enable_math=False, enable_mem_efficient=True, enable_flash=False):
|
||||
out_lp, atten_lp = torch.nn.functional._scaled_dot_product_attention(
|
||||
query_lp, key_lp, value_lp, None, 0.0, False, False)
|
||||
|
||||
rand_upward = torch.rand_like(out)
|
||||
rand_upward_lp = rand_upward.to(torch.float32)
|
||||
|
||||
out.backward(rand_upward)
|
||||
out_lp.backward(rand_upward_lp)
|
||||
|
||||
# Cast up and compare
|
||||
self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)
|
||||
|
||||
@parametrize("type", ["dense", "nested"])
|
||||
def test_fused_sdp_choice(self, type: str):
|
||||
|
|
@ -1144,7 +1194,7 @@ class TestTransformers(NNTestCase):
|
|||
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
if SM80OrLater:
|
||||
if SM80OrLater and not type == "nested":
|
||||
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION
|
||||
else:
|
||||
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION
|
||||
|
|
|
|||
|
|
@ -2613,9 +2613,13 @@
|
|||
nested_strides: non_differentiable
|
||||
|
||||
# Transformers
|
||||
- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor)
|
||||
output_differentiability: [True, False]
|
||||
query, key, value: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, result0, result1, is_causal)
|
||||
|
||||
- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor)
|
||||
output_differentiability: [True, False]
|
||||
query, key, value: _efficient_attention_backward(grad, query, key, value, result1, result0, causal)
|
||||
query, key, value: _efficient_attention_backward(grad, query, key, value, result0, result1, causal)
|
||||
|
||||
# fft
|
||||
- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
|
||||
|
|
|
|||
|
|
@ -12009,16 +12009,21 @@ op_db: List[OpInfo] = [
|
|||
# This is only failing on Linux Bionic 3.10 Cuda 11.6
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
|
||||
device_type='cuda', active_if=_get_torch_cuda_version() >= (11, 6)),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples',
|
||||
device_type='cuda', dtypes=(torch.float32,)),
|
||||
# AssertionError: JIT Test does not execute any logic
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
|
||||
# Doesn't support autocasting
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensorNonErroring', 'test_fake_autocast', device_type='cpu'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'),
|
||||
# Forward works for dtype=float64 which is the math path
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
|
||||
# No meta function
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
|
||||
DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', device_type='cuda'),
|
||||
DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'),),
|
||||
),
|
||||
UnaryUfuncInfo(
|
||||
|
|
|
|||
Loading…
Reference in a new issue