mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[cuDNN] cuDNN SDPA (Flash Attention) Backward (#122510)
Follow-up to #113713. Currently passing trivial smoke tests, but I just totally pattern-matched bits and pieces of the autograd defs. Will also collect benchmark data. CC @drisspg. Co-authored-by: Eli Uriegas <1700823+seemethere@users.noreply.github.com>. Pull Request resolved: https://github.com/pytorch/pytorch/pull/122510. Approved by: https://github.com/drisspg
This commit is contained in:
parent
5944a53555
commit
a866bfff45
8 changed files with 362 additions and 14 deletions
|
|
@ -29,6 +29,30 @@ void run_cudnn_SDP_fprop(
|
|||
false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
|
||||
}
|
||||
|
||||
void run_cudnn_SDP_bprop(
|
||||
int64_t b,
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool is_causal,
|
||||
float dropout_probability,
|
||||
const Tensor& q,
|
||||
const Tensor& k,
|
||||
const Tensor& v,
|
||||
const Tensor& o,
|
||||
const Tensor& dO,
|
||||
const Tensor& softmaxstats,
|
||||
Tensor& dQ,
|
||||
Tensor& dK,
|
||||
Tensor& dV,
|
||||
const Tensor& dropoutseed,
|
||||
const Tensor& dropoutoffset) {
|
||||
TORCH_CHECK(
|
||||
false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
|
|
@ -73,6 +97,22 @@ using graph_and_tensors = std::tuple<
|
|||
std::shared_ptr<fe::graph::Tensor_attributes> // Stats
|
||||
>;
|
||||
|
||||
using graph_and_tensors_backward = std::tuple<
|
||||
std::shared_ptr<fe::graph::Graph>,
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>, // Offset,
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>, // O,
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>, // dO,
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>, // stats,
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ,
|
||||
std::shared_ptr<fe::graph::Tensor_attributes>, // dK,,
|
||||
std::shared_ptr<fe::graph::Tensor_attributes> // dV,
|
||||
>;
|
||||
|
||||
#define MAX_MHA_DIM 4
|
||||
|
||||
struct MHAParams {
|
||||
|
|
@ -178,8 +218,7 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
|
|||
|
||||
template <typename T, typename KeyType>
|
||||
struct MHAGraphCache {
|
||||
std::unordered_map<KeyType, graph_and_tensors, ParamsWrapperHash<KeyType>>
|
||||
engine_cache;
|
||||
std::unordered_map<KeyType, T, ParamsWrapperHash<KeyType>> engine_cache;
|
||||
|
||||
// no mutexes here as caches are now thread local for v8, can also return a
|
||||
// pointer to the Execution Plan if we know it will not be invalidated by
|
||||
|
|
@ -202,6 +241,8 @@ struct MHAGraphCache {
|
|||
// be thread safe across all engines see Limitations in
|
||||
// https://docs.nvidia.com/deeplearning/cudnn/release-notes/index.html
|
||||
thread_local MHAGraphCache<graph_and_tensors, MHACacheKeyWrapper> mhagraphcache;
|
||||
thread_local MHAGraphCache<graph_and_tensors_backward, MHACacheKeyWrapper>
|
||||
mhagraphbackwardcache;
|
||||
|
||||
auto build_graph_and_tensors(
|
||||
int64_t b,
|
||||
|
|
@ -227,10 +268,12 @@ auto build_graph_and_tensors(
|
|||
dtype = fe::DataType_t::BFLOAT16;
|
||||
}
|
||||
auto mha_graph = std::make_shared<fe::graph::Graph>();
|
||||
// We're baking in float accumulation and scale types
|
||||
// in theory the graph may support other types, but they
|
||||
// have not been tested
|
||||
mha_graph->set_io_data_type(dtype)
|
||||
.set_intermediate_data_type(fe::DataType_t::FLOAT)
|
||||
.set_compute_data_type(fe::DataType_t::FLOAT);
|
||||
|
||||
auto Q = mha_graph->tensor(
|
||||
fe::graph::Tensor_attributes()
|
||||
.set_name("Q")
|
||||
|
|
@ -254,7 +297,7 @@ auto build_graph_and_tensors(
|
|||
params.v_stride.begin(), params.v_stride.end())));
|
||||
auto attn_scale =
|
||||
mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("attn_scale")
|
||||
.set_name("Attn_scale")
|
||||
.set_dim({1, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_is_pass_by_value(true)
|
||||
|
|
@ -276,7 +319,7 @@ auto build_graph_and_tensors(
|
|||
.set_data_type(fe::DataType_t::INT32));
|
||||
auto scaled_dot_product_flash_attention_options =
|
||||
fe::graph::SDPA_attributes()
|
||||
.set_name("flash_attention")
|
||||
.set_name("CUDNN_SDPA")
|
||||
.set_is_inference(return_softmaxstats == false)
|
||||
.set_causal_mask(is_causal)
|
||||
.set_attn_scale(attn_scale)
|
||||
|
|
@ -287,12 +330,12 @@ auto build_graph_and_tensors(
|
|||
}
|
||||
|
||||
auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("seq_q")
|
||||
.set_name("Seq_q")
|
||||
.set_dim({b, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_data_type(fe::DataType_t::INT32));
|
||||
auto seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("seq_kv")
|
||||
.set_name("Seq_kv")
|
||||
.set_dim({b, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_data_type(fe::DataType_t::INT32));
|
||||
|
|
@ -324,7 +367,146 @@ auto build_graph_and_tensors(
|
|||
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
|
||||
|
||||
return std::make_tuple(
|
||||
mha_graph, Q, K, V, attn_scale, seed, offset, O, Stats);
|
||||
std::move(mha_graph),
|
||||
std::move(Q),
|
||||
std::move(K),
|
||||
std::move(V),
|
||||
std::move(attn_scale),
|
||||
std::move(seed),
|
||||
std::move(offset),
|
||||
std::move(O),
|
||||
std::move(Stats));
|
||||
}
|
||||
|
||||
auto build_graph_and_tensors_backward(
|
||||
int64_t b,
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool is_causal,
|
||||
float dropout_probability,
|
||||
const Tensor& q,
|
||||
const Tensor& k,
|
||||
const Tensor& v,
|
||||
const Tensor& o,
|
||||
const Tensor& dO,
|
||||
const Tensor& softmaxstats,
|
||||
Tensor& dQ,
|
||||
Tensor& dK,
|
||||
Tensor& dV,
|
||||
const Tensor& dropoutseed,
|
||||
const Tensor& dropoutoffset,
|
||||
cudnnHandle_t& handle,
|
||||
MHAParams& params) {
|
||||
auto dtype = fe::DataType_t::HALF;
|
||||
if (q.scalar_type() == kBFloat16) {
|
||||
dtype = fe::DataType_t::BFLOAT16;
|
||||
}
|
||||
auto mha_graph = std::make_shared<fe::graph::Graph>();
|
||||
// We're baking in float accumulation and scale types
|
||||
// in theory the graph may support other types, but they
|
||||
// have not been tested
|
||||
mha_graph->set_io_data_type(dtype)
|
||||
.set_intermediate_data_type(fe::DataType_t::FLOAT)
|
||||
.set_compute_data_type(fe::DataType_t::FLOAT);
|
||||
auto Q = mha_graph->tensor(
|
||||
fe::graph::Tensor_attributes()
|
||||
.set_name("Q")
|
||||
.set_dim(std::vector<int64_t>(q.sizes().begin(), q.sizes().end()))
|
||||
.set_stride(
|
||||
std::vector<int64_t>(q.strides().begin(), q.strides().end())));
|
||||
auto K = mha_graph->tensor(
|
||||
fe::graph::Tensor_attributes()
|
||||
.set_name("K")
|
||||
.set_dim(std::vector<int64_t>(k.sizes().begin(), k.sizes().end()))
|
||||
.set_stride(
|
||||
std::vector<int64_t>(k.strides().begin(), k.strides().end())));
|
||||
auto V = mha_graph->tensor(
|
||||
fe::graph::Tensor_attributes()
|
||||
.set_name("V")
|
||||
.set_dim(std::vector<int64_t>(v.sizes().begin(), v.sizes().end()))
|
||||
.set_stride(
|
||||
std::vector<int64_t>(v.strides().begin(), v.strides().end())));
|
||||
auto attn_scale =
|
||||
mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("Attn_scale")
|
||||
.set_dim({1, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_is_pass_by_value(true)
|
||||
.set_data_type(fe::DataType_t::FLOAT));
|
||||
auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("Seed")
|
||||
.set_dim({1, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_data_type(fe::DataType_t::INT32));
|
||||
auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_name("Offset")
|
||||
.set_dim({1, 1, 1, 1})
|
||||
.set_stride({1, 1, 1, 1})
|
||||
.set_data_type(fe::DataType_t::INT32));
|
||||
auto O = mha_graph->tensor(
|
||||
fe::graph::Tensor_attributes()
|
||||
.set_name("O")
|
||||
.set_dim(std::vector<int64_t>(o.sizes().begin(), o.sizes().end()))
|
||||
.set_stride(
|
||||
std::vector<int64_t>(o.strides().begin(), o.strides().end())));
|
||||
auto STATS = mha_graph->tensor(
|
||||
fe::graph::Tensor_attributes()
|
||||
.set_name("Stats")
|
||||
.set_dim(std::vector<int64_t>(
|
||||
softmaxstats.sizes().begin(), softmaxstats.sizes().end()))
|
||||
.set_stride(std::vector<int64_t>(
|
||||
softmaxstats.strides().begin(), softmaxstats.strides().end()))
|
||||
.set_data_type(fe::DataType_t::FLOAT));
|
||||
auto DO = mha_graph->tensor(
|
||||
fe::graph::Tensor_attributes()
|
||||
.set_name("DO")
|
||||
.set_dim(std::vector<int64_t>(dO.sizes().begin(), dO.sizes().end()))
|
||||
.set_stride(
|
||||
std::vector<int64_t>(dO.strides().begin(), dO.strides().end())));
|
||||
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
|
||||
.set_name("CUDNN_SDPA_BACKWARD")
|
||||
.set_causal_mask(is_causal)
|
||||
.set_attn_scale(attn_scale);
|
||||
if (dropout_probability != 0.0f) {
|
||||
sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset);
|
||||
}
|
||||
auto [DQ, DK, DV] =
|
||||
mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options);
|
||||
DQ->set_output(true)
|
||||
.set_dim(std::vector<int64_t>(dQ.sizes().begin(), dQ.sizes().end()))
|
||||
.set_stride(
|
||||
std::vector<int64_t>(dQ.strides().begin(), dQ.strides().end()));
|
||||
DK->set_output(true)
|
||||
.set_dim(std::vector<int64_t>(dK.sizes().begin(), dK.sizes().end()))
|
||||
.set_stride(
|
||||
std::vector<int64_t>(dK.strides().begin(), dK.strides().end()));
|
||||
DV->set_output(true)
|
||||
.set_dim(std::vector<int64_t>(dV.sizes().begin(), dV.sizes().end()))
|
||||
.set_stride(
|
||||
std::vector<int64_t>(dV.strides().begin(), dV.strides().end()));
|
||||
AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
|
||||
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
|
||||
AT_CUDNN_FRONTEND_CHECK(
|
||||
mha_graph->create_execution_plans({fe::HeurMode_t::A}));
|
||||
AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
|
||||
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
|
||||
return std::make_tuple(
|
||||
std::move(mha_graph),
|
||||
std::move(Q),
|
||||
std::move(K),
|
||||
std::move(V),
|
||||
std::move(attn_scale),
|
||||
std::move(Seed),
|
||||
std::move(Offset),
|
||||
std::move(O),
|
||||
std::move(DO),
|
||||
std::move(STATS),
|
||||
std::move(DQ),
|
||||
std::move(DK),
|
||||
std::move(DV));
|
||||
}
|
||||
|
||||
void run_cudnn_SDP_fprop(
|
||||
|
|
@ -407,11 +589,92 @@ void run_cudnn_SDP_fprop(
|
|||
auto workspace_size = mha_graph->get_workspace_size();
|
||||
auto workspace_ptr =
|
||||
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
TORCH_CHECK(
|
||||
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
|
||||
mhagraphcache.update(key, graph_and_tensors_values);
|
||||
}
|
||||
|
||||
void run_cudnn_SDP_bprop(
|
||||
int64_t b,
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool is_causal,
|
||||
float dropout_probability,
|
||||
const Tensor& q,
|
||||
const Tensor& k,
|
||||
const Tensor& v,
|
||||
const Tensor& o,
|
||||
const Tensor& dO,
|
||||
const Tensor& softmaxstats,
|
||||
Tensor& dQ,
|
||||
Tensor& dK,
|
||||
Tensor& dV,
|
||||
const Tensor& dropoutseed,
|
||||
const Tensor& dropoutoffset) {
|
||||
cudnnHandle_t handle = getCudnnHandle();
|
||||
auto key = MHACacheKeyWrapper(
|
||||
b, h, s_q, s_kv, d, q, k, v, dropout_probability, is_causal, true);
|
||||
auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key);
|
||||
graph_and_tensors_backward graph_and_tensors_backward_values;
|
||||
if (graph_and_tensors_backward_ptr) {
|
||||
graph_and_tensors_backward_values = *graph_and_tensors_backward_ptr;
|
||||
} else {
|
||||
graph_and_tensors_backward_values = build_graph_and_tensors_backward(
|
||||
b,
|
||||
h,
|
||||
s_q,
|
||||
s_kv,
|
||||
d,
|
||||
scaling_factor,
|
||||
is_causal,
|
||||
dropout_probability,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
dO,
|
||||
softmaxstats,
|
||||
dQ,
|
||||
dK,
|
||||
dV,
|
||||
dropoutseed,
|
||||
dropoutoffset,
|
||||
handle,
|
||||
key.pod);
|
||||
}
|
||||
auto
|
||||
[mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] =
|
||||
graph_and_tensors_backward_values;
|
||||
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
|
||||
variant_pack = {// inputs
|
||||
{Q, q.data_ptr()},
|
||||
{K, k.data_ptr()},
|
||||
{V, v.data_ptr()},
|
||||
{O, o.data_ptr()},
|
||||
{Do, dO.data_ptr()},
|
||||
{Stats, softmaxstats.data_ptr()},
|
||||
// outputs
|
||||
{Dq, dQ.data_ptr()},
|
||||
{Dk, dK.data_ptr()},
|
||||
{Dv, dV.data_ptr()},
|
||||
// pass by value
|
||||
{attn_scale, &scaling_factor}};
|
||||
if (dropout_probability != 0.0f) {
|
||||
variant_pack[Seed] = dropoutseed.data_ptr();
|
||||
variant_pack[Offset] = dropoutoffset.data_ptr();
|
||||
}
|
||||
auto workspace_size = mha_graph->get_workspace_size();
|
||||
auto workspace_ptr =
|
||||
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
|
||||
TORCH_CHECK(!workspace_size || workspace_ptr.get());
|
||||
TORCH_CHECK(
|
||||
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
|
||||
mhagraphbackwardcache.update(key, graph_and_tensors_backward_values);
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
|
|
|
|||
|
|
@ -21,5 +21,27 @@ void run_cudnn_SDP_fprop(
|
|||
Tensor& o,
|
||||
Tensor& dropoutseed,
|
||||
Tensor& dropoutoffset);
|
||||
}
|
||||
|
||||
void run_cudnn_SDP_bprop(
|
||||
int64_t b,
|
||||
int64_t h,
|
||||
int64_t s_q,
|
||||
int64_t s_kv,
|
||||
int64_t d,
|
||||
float scaling_factor,
|
||||
bool is_causal,
|
||||
float dropout_probability,
|
||||
const Tensor& q,
|
||||
const Tensor& k,
|
||||
const Tensor& v,
|
||||
const Tensor& o,
|
||||
const Tensor& dO,
|
||||
const Tensor& softmaxstats,
|
||||
Tensor& dQ,
|
||||
Tensor& dK,
|
||||
Tensor& dV,
|
||||
const Tensor& dropoutseed,
|
||||
const Tensor& dropoutoffset);
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -14700,11 +14700,16 @@
|
|||
CUDA: _scaled_dot_product_efficient_attention_backward_cuda
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
|
||||
- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
dispatch:
|
||||
CUDA: _scaled_dot_product_cudnn_attention_cuda
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
|
||||
dispatch:
|
||||
CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
variants: function
|
||||
dispatch:
|
||||
|
|
|
|||
|
|
@ -726,7 +726,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Ten
|
|||
return std::make_tuple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, philox_seed, philox_offset, debug_attn_mask);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_cuda(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
|
|
@ -773,7 +773,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c
|
|||
cudnn_seed/*Tensor dropoutseed*/,
|
||||
cudnn_offset/*Tensor dropoutoffset*/);
|
||||
|
||||
return std::make_tuple(attention, log_sumexp, cudnn_seed, cudnn_offset);
|
||||
return std::make_tuple(attention, log_sumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, cudnn_seed, cudnn_offset, Tensor());
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
|
||||
|
|
|
|||
|
|
@ -43,6 +43,12 @@
|
|||
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
#include <ATen/native/cudnn/hip/MHA.h>
|
||||
#else
|
||||
#include <ATen/native/cudnn/MHA.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
||||
|
|
@ -117,7 +123,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
|||
determinisitic,
|
||||
philox_seed,
|
||||
philox_offset);
|
||||
return std::make_tuple(dQuery, dKey, dValue);
|
||||
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue));
|
||||
} else {
|
||||
// Dense forward
|
||||
auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_bwd(
|
||||
|
|
@ -146,6 +152,52 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
|||
return std::make_tuple(Tensor(), Tensor(), Tensor());
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_cuda(
|
||||
const Tensor& grad_out,
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
const Tensor& out,
|
||||
const Tensor& logsumexp,
|
||||
const Tensor& cumulative_sequence_length_q,
|
||||
const Tensor& cumulative_sequence_length_k,
|
||||
const int64_t max_seqlen_batch_q,
|
||||
const int64_t max_seqlen_batch_k,
|
||||
double dropout_p,
|
||||
bool is_causal,
|
||||
const Tensor& philox_seed,
|
||||
const Tensor& philox_offset,
|
||||
c10::optional<double> scale) {
|
||||
const int64_t batch_size = query.size(0);
|
||||
const int64_t num_heads = query.size(1);
|
||||
const int64_t head_dim = query.size(3);
|
||||
|
||||
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
|
||||
|
||||
auto dq = at::empty_like(query);
|
||||
auto dk = at::empty_like(key);
|
||||
auto dv = at::empty_like(value);
|
||||
run_cudnn_SDP_bprop(batch_size /*int64_t b*/,
|
||||
num_heads /*int64_t h*/,
|
||||
max_seqlen_batch_q /*int64_t s_q*/,
|
||||
max_seqlen_batch_k /*int64_t s_kv*/,
|
||||
head_dim /*int64_t d*/,
|
||||
softmax_scale /*float scaling_factor*/,
|
||||
is_causal /*bool is_causal*/,
|
||||
dropout_p /*float dropout_probability*/,
|
||||
query /*const Tensor& q*/,
|
||||
key /*const Tensor& k*/,
|
||||
value /*const Tensor& v*/,
|
||||
out /*const Tensor& o*/,
|
||||
grad_out/*const Tensor& dO*/,
|
||||
logsumexp.unsqueeze(-1)/*const Tensor& softmaxstats*/,
|
||||
dq/*Tensor& dQ*/,
|
||||
dk/*Tensor& dK*/,
|
||||
dv/*Tensor& dV*/,
|
||||
philox_seed/*Tensor& dropoutseed*/,
|
||||
philox_offset/*Tensor& dropoutoffset*/);
|
||||
return std::make_tuple(dq, dk, dv);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
_efficient_attention_backward(
|
||||
|
|
|
|||
|
|
@ -487,6 +487,7 @@ aten::_resize_output_
|
|||
aten::_sample_dirichlet
|
||||
aten::_sample_dirichlet.out
|
||||
aten::_scaled_dot_product_cudnn_attention
|
||||
aten::_scaled_dot_product_cudnn_attention_backward
|
||||
aten::_scaled_dot_product_efficient_attention
|
||||
aten::_scaled_dot_product_efficient_attention_backward
|
||||
aten::_scaled_dot_product_flash_attention
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ ALLOW_LIST = [
|
|||
("prims::.*", datetime.date(9999, 1, 1)),
|
||||
("aten::_flash_attention_forward", datetime.date(2023, 12, 30)),
|
||||
("aten::_flash_attention_backward", datetime.date(2023, 12, 30)),
|
||||
("aten::_scaled_dot_product_cudnn_attention", datetime.date(9999, 1, 1)),
|
||||
("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
|
||||
# BetterTransformer 1.0 internal operators
|
||||
("aten::_transformer_decoder_only_layer_fwd", datetime.date(9999, 1, 1)),
|
||||
|
|
|
|||
|
|
@ -2823,6 +2823,10 @@
|
|||
output_differentiability: [True, False, False, False, False, False]
|
||||
query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)
|
||||
|
||||
- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
output_differentiability: [True, False, False, False, False, False, False, False, False]
|
||||
query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
|
||||
|
||||
# fft
|
||||
- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
|
||||
self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back()))
|
||||
|
|
|
|||
Loading…
Reference in a new issue