[cuDNN] cuDNN SDPA (Flash Attention) Backward (#122510)

#113713
Currently passing trivial smoke tests, though the autograd definitions were largely pattern-matched from existing implementations rather than derived from first principles.

Will also collect benchmark data.

CC @drisspg

Co-authored-by: Eli Uriegas <1700823+seemethere@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122510
Approved by: https://github.com/drisspg
This commit is contained in:
eqy 2024-04-27 04:15:47 +00:00 committed by PyTorch MergeBot
parent 5944a53555
commit a866bfff45
8 changed files with 362 additions and 14 deletions

View file

@ -29,6 +29,30 @@ void run_cudnn_SDP_fprop(
false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
}
// Fallback stub compiled when cuDNN Flash Attention support is disabled:
// mirrors the real implementation's signature so callers link against the
// same symbol, but unconditionally fails at runtime.
void run_cudnn_SDP_bprop(
    int64_t b,
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
    int64_t d,
    float scaling_factor,
    bool is_causal,
    float dropout_probability,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const Tensor& o,
    const Tensor& dO,
    const Tensor& softmaxstats,
    Tensor& dQ,
    Tensor& dK,
    Tensor& dV,
    const Tensor& dropoutseed,
    const Tensor& dropoutoffset) {
  TORCH_CHECK(
      false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
}
} // namespace native
} // namespace at
@ -73,6 +97,22 @@ using graph_and_tensors = std::tuple<
std::shared_ptr<fe::graph::Tensor_attributes> // Stats
>;
// Cached handles for the backward (bprop) SDPA graph: the compiled cuDNN
// frontend graph plus the tensor-attribute nodes needed to bind device
// pointers into a variant pack at execution time.
using graph_and_tensors_backward = std::tuple<
    std::shared_ptr<fe::graph::Graph>,
    std::shared_ptr<fe::graph::Tensor_attributes>, // Q
    std::shared_ptr<fe::graph::Tensor_attributes>, // K
    std::shared_ptr<fe::graph::Tensor_attributes>, // V
    std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale
    std::shared_ptr<fe::graph::Tensor_attributes>, // Seed
    std::shared_ptr<fe::graph::Tensor_attributes>, // Offset
    std::shared_ptr<fe::graph::Tensor_attributes>, // O
    std::shared_ptr<fe::graph::Tensor_attributes>, // dO
    std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
    std::shared_ptr<fe::graph::Tensor_attributes>, // dQ
    std::shared_ptr<fe::graph::Tensor_attributes>, // dK
    std::shared_ptr<fe::graph::Tensor_attributes> // dV
>;
#define MAX_MHA_DIM 4
struct MHAParams {
@ -178,8 +218,7 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
template <typename T, typename KeyType>
struct MHAGraphCache {
std::unordered_map<KeyType, graph_and_tensors, ParamsWrapperHash<KeyType>>
engine_cache;
std::unordered_map<KeyType, T, ParamsWrapperHash<KeyType>> engine_cache;
// no mutexes here as caches are now thread local for v8, can also return a
// pointer to the Execution Plan if we know it will not be invalidated by
@ -202,6 +241,8 @@ struct MHAGraphCache {
// be thread safe across all engines see Limitations in
// https://docs.nvidia.com/deeplearning/cudnn/release-notes/index.html
thread_local MHAGraphCache<graph_and_tensors, MHACacheKeyWrapper> mhagraphcache;
thread_local MHAGraphCache<graph_and_tensors_backward, MHACacheKeyWrapper>
mhagraphbackwardcache;
auto build_graph_and_tensors(
int64_t b,
@ -227,10 +268,12 @@ auto build_graph_and_tensors(
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
mha_graph->set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
@ -254,7 +297,7 @@ auto build_graph_and_tensors(
params.v_stride.begin(), params.v_stride.end())));
auto attn_scale =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("attn_scale")
.set_name("Attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
@ -276,7 +319,7 @@ auto build_graph_and_tensors(
.set_data_type(fe::DataType_t::INT32));
auto scaled_dot_product_flash_attention_options =
fe::graph::SDPA_attributes()
.set_name("flash_attention")
.set_name("CUDNN_SDPA")
.set_is_inference(return_softmaxstats == false)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale)
@ -287,12 +330,12 @@ auto build_graph_and_tensors(
}
auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_q")
.set_name("Seq_q")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("seq_kv")
.set_name("Seq_kv")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
@ -324,7 +367,146 @@ auto build_graph_and_tensors(
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
return std::make_tuple(
mha_graph, Q, K, V, attn_scale, seed, offset, O, Stats);
std::move(mha_graph),
std::move(Q),
std::move(K),
std::move(V),
std::move(attn_scale),
std::move(seed),
std::move(offset),
std::move(O),
std::move(Stats));
}
// Builds (but does not execute) the cuDNN frontend graph for the SDPA
// backward pass, returning the compiled graph together with the
// tensor-attribute nodes needed to bind real device pointers later
// (see graph_and_tensors_backward).
// NOTE(review): tensor dims/strides are taken directly from the ATen
// tensors below; `params` appears unused here beyond keeping the
// signature uniform with the forward builder — TODO confirm.
auto build_graph_and_tensors_backward(
    int64_t b,
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
    int64_t d,
    float scaling_factor,
    bool is_causal,
    float dropout_probability,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const Tensor& o,
    const Tensor& dO,
    const Tensor& softmaxstats,
    Tensor& dQ,
    Tensor& dK,
    Tensor& dV,
    const Tensor& dropoutseed,
    const Tensor& dropoutoffset,
    cudnnHandle_t& handle,
    MHAParams& params) {
  // I/O dtype follows q: HALF by default, BFLOAT16 when q is bfloat16.
  auto dtype = fe::DataType_t::HALF;
  if (q.scalar_type() == kBFloat16) {
    dtype = fe::DataType_t::BFLOAT16;
  }
  auto mha_graph = std::make_shared<fe::graph::Graph>();
  // We're baking in float accumulation and scale types
  // in theory the graph may support other types, but they
  // have not been tested
  mha_graph->set_io_data_type(dtype)
      .set_intermediate_data_type(fe::DataType_t::FLOAT)
      .set_compute_data_type(fe::DataType_t::FLOAT);
  // Graph inputs: dims and strides mirror the ATen tensors exactly so the
  // graph consumes them without layout conversion.
  auto Q = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("Q")
          .set_dim(std::vector<int64_t>(q.sizes().begin(), q.sizes().end()))
          .set_stride(
              std::vector<int64_t>(q.strides().begin(), q.strides().end())));
  auto K = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("K")
          .set_dim(std::vector<int64_t>(k.sizes().begin(), k.sizes().end()))
          .set_stride(
              std::vector<int64_t>(k.strides().begin(), k.strides().end())));
  auto V = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("V")
          .set_dim(std::vector<int64_t>(v.sizes().begin(), v.sizes().end()))
          .set_stride(
              std::vector<int64_t>(v.strides().begin(), v.strides().end())));
  // Softmax scale is a host scalar passed by value, not a device buffer.
  auto attn_scale =
      mha_graph->tensor(fe::graph::Tensor_attributes()
                            .set_name("Attn_scale")
                            .set_dim({1, 1, 1, 1})
                            .set_stride({1, 1, 1, 1})
                            .set_is_pass_by_value(true)
                            .set_data_type(fe::DataType_t::FLOAT));
  // Dropout RNG seed/offset nodes; only bound into the variant pack when
  // dropout is actually enabled (see run_cudnn_SDP_bprop).
  auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes()
                                    .set_name("Seed")
                                    .set_dim({1, 1, 1, 1})
                                    .set_stride({1, 1, 1, 1})
                                    .set_data_type(fe::DataType_t::INT32));
  auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes()
                                      .set_name("Offset")
                                      .set_dim({1, 1, 1, 1})
                                      .set_stride({1, 1, 1, 1})
                                      .set_data_type(fe::DataType_t::INT32));
  auto O = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("O")
          .set_dim(std::vector<int64_t>(o.sizes().begin(), o.sizes().end()))
          .set_stride(
              std::vector<int64_t>(o.strides().begin(), o.strides().end())));
  // Softmax statistics saved by the forward pass; kept FLOAT regardless of
  // the I/O dtype.
  auto STATS = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("Stats")
          .set_dim(std::vector<int64_t>(
              softmaxstats.sizes().begin(), softmaxstats.sizes().end()))
          .set_stride(std::vector<int64_t>(
              softmaxstats.strides().begin(), softmaxstats.strides().end()))
          .set_data_type(fe::DataType_t::FLOAT));
  auto DO = mha_graph->tensor(
      fe::graph::Tensor_attributes()
          .set_name("DO")
          .set_dim(std::vector<int64_t>(dO.sizes().begin(), dO.sizes().end()))
          .set_stride(
              std::vector<int64_t>(dO.strides().begin(), dO.strides().end())));
  auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
                                   .set_name("CUDNN_SDPA_BACKWARD")
                                   .set_causal_mask(is_causal)
                                   .set_attn_scale(attn_scale);
  if (dropout_probability != 0.0f) {
    sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset);
  }
  auto [DQ, DK, DV] =
      mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options);
  // Mark the gradients as graph outputs and pin their layouts to the
  // preallocated ATen output tensors dQ/dK/dV.
  DQ->set_output(true)
      .set_dim(std::vector<int64_t>(dQ.sizes().begin(), dQ.sizes().end()))
      .set_stride(
          std::vector<int64_t>(dQ.strides().begin(), dQ.strides().end()));
  DK->set_output(true)
      .set_dim(std::vector<int64_t>(dK.sizes().begin(), dK.sizes().end()))
      .set_stride(
          std::vector<int64_t>(dK.strides().begin(), dK.strides().end()));
  DV->set_output(true)
      .set_dim(std::vector<int64_t>(dV.sizes().begin(), dV.sizes().end()))
      .set_stride(
          std::vector<int64_t>(dV.strides().begin(), dV.strides().end()));
  // Validate, lower, and compile the graph up-front; the caller caches the
  // result so this cost is paid once per unique configuration.
  AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
  AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
  AT_CUDNN_FRONTEND_CHECK(
      mha_graph->create_execution_plans({fe::HeurMode_t::A}));
  AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
  AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
  return std::make_tuple(
      std::move(mha_graph),
      std::move(Q),
      std::move(K),
      std::move(V),
      std::move(attn_scale),
      std::move(Seed),
      std::move(Offset),
      std::move(O),
      std::move(DO),
      std::move(STATS),
      std::move(DQ),
      std::move(DK),
      std::move(DV));
}
void run_cudnn_SDP_fprop(
@ -407,11 +589,92 @@ void run_cudnn_SDP_fprop(
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_INTERNAL_ASSERT(
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
mhagraphcache.update(key, graph_and_tensors_values);
}
// Executes the cuDNN SDPA backward graph, building (and caching) the graph
// on first use for a given problem configuration. Writes gradients into the
// preallocated dQ/dK/dV tensors.
// NOTE(review): the trailing `true` in the cache key presumably marks the
// training/stats variant so forward and backward graphs hash to distinct
// entries — TODO confirm against MHACacheKeyWrapper's constructor.
void run_cudnn_SDP_bprop(
    int64_t b,
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
    int64_t d,
    float scaling_factor,
    bool is_causal,
    float dropout_probability,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const Tensor& o,
    const Tensor& dO,
    const Tensor& softmaxstats,
    Tensor& dQ,
    Tensor& dK,
    Tensor& dV,
    const Tensor& dropoutseed,
    const Tensor& dropoutoffset) {
  cudnnHandle_t handle = getCudnnHandle();
  auto key = MHACacheKeyWrapper(
      b, h, s_q, s_kv, d, q, k, v, dropout_probability, is_causal, true);
  // Graph compilation is expensive; reuse a previously built graph from the
  // thread_local cache when the configuration matches.
  auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key);
  graph_and_tensors_backward graph_and_tensors_backward_values;
  if (graph_and_tensors_backward_ptr) {
    graph_and_tensors_backward_values = *graph_and_tensors_backward_ptr;
  } else {
    graph_and_tensors_backward_values = build_graph_and_tensors_backward(
        b,
        h,
        s_q,
        s_kv,
        d,
        scaling_factor,
        is_causal,
        dropout_probability,
        q,
        k,
        v,
        o,
        dO,
        softmaxstats,
        dQ,
        dK,
        dV,
        dropoutseed,
        dropoutoffset,
        handle,
        key.pod);
  }
  auto
      [mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] =
          graph_and_tensors_backward_values;
  // Bind the graph's tensor nodes to the actual device pointers for this
  // invocation.
  std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
      variant_pack = {// inputs
                      {Q, q.data_ptr()},
                      {K, k.data_ptr()},
                      {V, v.data_ptr()},
                      {O, o.data_ptr()},
                      {Do, dO.data_ptr()},
                      {Stats, softmaxstats.data_ptr()},
                      // outputs
                      {Dq, dQ.data_ptr()},
                      {Dk, dK.data_ptr()},
                      {Dv, dV.data_ptr()},
                      // pass by value
                      {attn_scale, &scaling_factor}};
  // Seed/Offset are only consumed by the graph when dropout is enabled.
  if (dropout_probability != 0.0f) {
    variant_pack[Seed] = dropoutseed.data_ptr();
    variant_pack[Offset] = dropoutoffset.data_ptr();
  }
  auto workspace_size = mha_graph->get_workspace_size();
  auto workspace_ptr =
      c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
  // A zero-byte workspace legitimately yields a null pointer; otherwise the
  // allocation must have succeeded.
  TORCH_CHECK(!workspace_size || workspace_ptr.get());
  TORCH_CHECK(
      mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
  mhagraphbackwardcache.update(key, graph_and_tensors_backward_values);
}
} // namespace native
} // namespace at

View file

@ -21,5 +21,27 @@ void run_cudnn_SDP_fprop(
Tensor& o,
Tensor& dropoutseed,
Tensor& dropoutoffset);
}
// Backward (gradient) pass for cuDNN SDPA (Flash Attention).
// Fills dQ/dK/dV in place given the forward inputs (q, k, v), the forward
// outputs (o, softmaxstats), the upstream gradient dO, and the dropout
// seed/offset used by the forward pass.
void run_cudnn_SDP_bprop(
    int64_t b,
    int64_t h,
    int64_t s_q,
    int64_t s_kv,
    int64_t d,
    float scaling_factor,
    bool is_causal,
    float dropout_probability,
    const Tensor& q,
    const Tensor& k,
    const Tensor& v,
    const Tensor& o,
    const Tensor& dO,
    const Tensor& softmaxstats,
    Tensor& dQ,
    Tensor& dK,
    Tensor& dV,
    const Tensor& dropoutseed,
    const Tensor& dropoutoffset);
} // namespace native
} // namespace at

View file

@ -14700,11 +14700,16 @@
CUDA: _scaled_dot_product_efficient_attention_backward_cuda
tags: nondeterministic_seeded
- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
dispatch:
CUDA: _scaled_dot_product_cudnn_attention_cuda
tags: nondeterministic_seeded
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
dispatch:
CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
tags: nondeterministic_seeded
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
variants: function
dispatch:

View file

@ -726,7 +726,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
return std::make_tuple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, philox_seed, philox_offset, debug_attn_mask);
}
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
const Tensor& query,
const Tensor& key,
const Tensor& value,
@ -773,7 +773,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c
cudnn_seed/*Tensor dropoutseed*/,
cudnn_offset/*Tensor dropoutoffset*/);
return std::make_tuple(attention, log_sumexp, cudnn_seed, cudnn_offset);
return std::make_tuple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed, cudnn_offset, Tensor());
}
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(

View file

@ -43,6 +43,12 @@
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
#endif
#ifdef __HIP_PLATFORM_AMD__
#include <ATen/native/cudnn/hip/MHA.h>
#else
#include <ATen/native/cudnn/MHA.h>
#endif
namespace at::native {
std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
@ -117,7 +123,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
determinisitic,
philox_seed,
philox_offset);
return std::make_tuple(dQuery, dKey, dValue);
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue));
} else {
// Dense forward
auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_bwd(
@ -146,6 +152,52 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
return std::make_tuple(Tensor(), Tensor(), Tensor());
}
// Backward pass for cuDNN scaled dot-product attention (Flash Attention).
// Computes gradients of (query, key, value) given the upstream gradient
// `grad_out`, the forward output `out`, and the saved `logsumexp` stats.
// NOTE(review): cumulative_sequence_length_q/k are accepted for signature
// parity with the other SDPA backward kernels but are unused here —
// presumably only the dense (non-varlen) layout is supported; confirm.
std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
    const Tensor& grad_out,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& out,
    const Tensor& logsumexp,
    const Tensor& cumulative_sequence_length_q,
    const Tensor& cumulative_sequence_length_k,
    const int64_t max_seqlen_batch_q,
    const int64_t max_seqlen_batch_k,
    double dropout_p,
    bool is_causal,
    const Tensor& philox_seed,
    const Tensor& philox_offset,
    c10::optional<double> scale) {
  // Problem dims are read off `query`; assumes (batch, heads, seq, head_dim)
  // layout — TODO confirm this matches the forward's permutation.
  const int64_t batch_size = query.size(0);
  const int64_t num_heads = query.size(1);
  const int64_t head_dim = query.size(3);
  const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
  // Gradient buffers are preallocated here and filled in place by the kernel.
  auto dq = at::empty_like(query);
  auto dk = at::empty_like(key);
  auto dv = at::empty_like(value);
  run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
                      num_heads /*int64_t h*/,
                      max_seqlen_batch_q /*int64_t s_q*/,
                      max_seqlen_batch_k /*int64_t s_kv*/,
                      head_dim /*int64_t d*/,
                      softmax_scale /*float scaling_factor*/,
                      is_causal /*bool is_causal*/,
                      dropout_p /*float dropout_probability*/,
                      query /*const Tensor& q*/,
                      key /*const Tensor& k*/,
                      value /*const Tensor& v*/,
                      out /*const Tensor& o*/,
                      grad_out/*const Tensor& dO*/,
                      // unsqueeze: stats apparently need a trailing
                      // singleton dim for the cuDNN graph — see builder
                      logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
                      dq/*Tensor& dQ*/,
                      dk/*Tensor& dK*/,
                      dv/*Tensor& dV*/,
                      philox_seed/*const Tensor& dropoutseed*/,
                      philox_offset/*const Tensor& dropoutoffset*/);
  // std::move avoids redundant TensorImpl refcount bumps on return,
  // matching the convention used by _flash_attention_backward.
  return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_efficient_attention_backward(

View file

@ -487,6 +487,7 @@ aten::_resize_output_
aten::_sample_dirichlet
aten::_sample_dirichlet.out
aten::_scaled_dot_product_cudnn_attention
aten::_scaled_dot_product_cudnn_attention_backward
aten::_scaled_dot_product_efficient_attention
aten::_scaled_dot_product_efficient_attention_backward
aten::_scaled_dot_product_flash_attention

View file

@ -111,6 +111,7 @@ ALLOW_LIST = [
("prims::.*", datetime.date(9999, 1, 1)),
("aten::_flash_attention_forward", datetime.date(2023, 12, 30)),
("aten::_flash_attention_backward", datetime.date(2023, 12, 30)),
("aten::_scaled_dot_product_cudnn_attention", datetime.date(9999, 1, 1)),
("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
# BetterTransformer 1.0 internal operators
("aten::_transformer_decoder_only_layer_fwd", datetime.date(9999, 1, 1)),

View file

@ -2823,6 +2823,10 @@
output_differentiability: [True, False, False, False, False, False]
query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)
- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
output_differentiability: [True, False, False, False, False, False, False, False, False]
query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
# fft
- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back()))