[Sync torch_FA2 and FA2 flash_api] + [Expose seqused_k & alibi_slopes arguments] (#126520)

1. **Expose seqused_k & alibi_slopes arguments**:
- This can be used when your sequence length k is not the full extent of the tensor. This is useful for kv cache scenarios and was not previously supported in the FA2 TORCH integration. We need these arguments for external xformers lib call to the _flash_attention_forward API.
Before:
```
  std::optional<Tensor> seqused_k = c10::nullopt;
  std::optional<Tensor> alibi_slopes = c10::nullopt;
```
After:
```
_flash_attention_forward(...
    std::optional<Tensor>& seqused_k,
    std::optional<Tensor>& alibi_slopes,
```

2. There is a difference between the **TORCH_FA2_flash_api:mha_fwd** and **FA2_flash_api:mha_fwd** (same for **mha_varlen_fwd**) at the query transposition (GQA) step.

The **CHECK_SHAPE** is applied on the original query vs the reshaped query. This causes an error (because of the shape constraint) for such inputs:
```
q = torch.randn([7, 1, 4, 256], dtype=torch.bfloat16, device='cuda')
k = torch.randn([7, 51, 1, 256], dtype=torch.bfloat16, device='cuda')
v = torch.randn([7, 51, 1, 256], dtype=torch.bfloat16, device='cuda')
```

![image](https://github.com/pytorch/pytorch/assets/927999/77ea6bf6-b6e9-4f3f-96a9-8d952956ddd9)

- i've modified the code as little as possible, but if you prefer a more verbose change like the following, dont hesitate to tell me:
```
at::Tensor swapped_q = seqlenq_ngroups_swapped
    ? q.reshape({batch_size, num_heads_k, num_heads / num_heads_k, head_size_og}).transpose(1, 2)
    : q;

if (seqlenq_ngroups_swapped) {
    seqlen_q = num_heads / num_heads_k;
    num_heads = num_heads_k;
}

CHECK_SHAPE(swapped_q, batch_size, seqlen_q, num_heads, head_size_og);
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126520
Approved by: https://github.com/drisspg
This commit is contained in:
Valeriu 2024-05-29 11:54:44 +00:00 committed by PyTorch MergeBot
parent dae33a4961
commit 02b1cdab23
7 changed files with 19 additions and 13 deletions

View file

@ -14720,7 +14720,7 @@
CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
tags: nondeterministic_seeded
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
variants: function
dispatch:
CUDA: _flash_attention_forward

View file

@ -265,7 +265,9 @@ _scaled_dot_product_flash_attention_nestedtensor_cuda(
dropout_p,
is_causal,
return_debug_mask,
scale);
scale,
c10::nullopt,
c10::nullopt);
// Reshape output to convert nnz to batch_size and seq_len
attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
return std::make_tuple(

View file

@ -719,7 +719,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
dropout_p,
is_causal,
return_debug_mask,
scale);
scale,
c10::nullopt,
c10::nullopt);
// Reshape output to convert nnz to batch_size and seq_len
Tensor attention = output.transpose(1,2);
@ -843,16 +845,17 @@ _flash_attention_forward(
bool return_debug_mask,
std::optional<double> scale,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right) {
std::optional<int64_t> window_size_right,
const std::optional<Tensor>& _seqused_k,
const std::optional<Tensor>& _alibi_slopes
) {
#if defined(USE_FLASH_ATTENTION)
const auto softmax_scale =
sdp::calculate_scale(query, scale).as_float_unchecked();
std::optional<Tensor> out = c10::nullopt;
// This can be used when your sequence length k is not the full extent
// of the tensor. This is useful for kv cache scenarios but for now
// we will not support in this PR.
std::optional<Tensor> seqused_k = c10::nullopt;
std::optional<Tensor> alibi_slopes = c10::nullopt;
std::optional<Tensor> seqused_k = _seqused_k;
std::optional<Tensor> alibi_slopes = _alibi_slopes;
const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;

View file

@ -410,7 +410,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
num_heads = num_heads_k;
}
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(temp_q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
@ -612,7 +612,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, total_q, num_heads, head_size_og);
CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);

View file

@ -2816,7 +2816,7 @@
output_differentiability: [True, False]
query, key, value: _scaled_dot_product_flash_attention_for_cpu_backward(grad, query, key, value, output, logsumexp, dropout_p, is_causal, attn_mask, scale)
- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
output_differentiability: [True, False, False, False, False]
query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale, window_size_left, window_size_right)

View file

@ -801,6 +801,7 @@ def meta__flash_attention_forward(fake_mode, func, *args, **kwargs):
max_k = kwargs["max_k"]
return_debug_mask = kwargs["return_debug_mask"]
# unused: value, dropout_p, is_causal, scale
# unused: seqused_k, alibi_slopes, window_size_left, window_size_right
def convert_tensor(t, device):
return FakeTensor(fake_mode, t, device)

View file

@ -29,7 +29,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_per_sample_weigh
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__pdist_backward(AtenTensorHandle grad, AtenTensorHandle self, double p, AtenTensorHandle pdist, AtenTensorHandle* ret0);