pytorch/docs/source
Andy Lugo 500d02921b [ROCm] CK Flash Attention Backend (#138947)
Replaces https://github.com/ROCm/pytorch/pull/1592

This PR contains the initial implementation of SDPA with the composable_kernel (CK) backend. The CK path can be forced by calling `torch.backends.cuda.preferred_rocm_fa_library("ck")`. Similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". If this option is not set, aotriton is used as the backend. In the case of CK, if PyTorch deems flash attention usable, it will use the CK path in all the same places aotriton would have been used. This PR makes no changes to the heuristics that select which attention scheme to use (i.e. flash attention vs. memory efficient attention vs. math, etc.). The CK path only gets called when flash attention is both enabled (via `USE_FLASH_ATTENTION`) and selected at runtime by the existing heuristics.
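As a minimal sketch of the selection call described above: the helper name `prefer_rocm_fa_backend` and its guards are hypothetical and not part of this PR. Only `torch.backends.cuda.preferred_rocm_fa_library` and the strings "ck", "aotriton" and "default" come from the description. The guards are an assumption so that the snippet degrades to a no-op on machines without a ROCm build of PyTorch.

```python
import importlib.util

# Backend names taken from the PR description.
VALID_BACKENDS = ("default", "aotriton", "ck")


def prefer_rocm_fa_backend(name: str = "ck") -> bool:
    """Hypothetical helper: ask PyTorch to prefer a ROCm flash attention library.

    Returns True if the request was accepted by PyTorch, False if this
    environment cannot honour it (no torch, no ROCm build, or no such API).
    True does not prove that CK kernels will run: the existing SDPA
    heuristics still decide whether flash attention is used at all.
    """
    if name not in VALID_BACKENDS:
        raise ValueError(f"unknown backend {name!r}, expected one of {VALID_BACKENDS}")

    # torch may not be installed at all.
    if importlib.util.find_spec("torch") is None:
        return False
    import torch

    # torch.version.hip is None on non-ROCm builds.
    if getattr(torch.version, "hip", None) is None:
        return False

    # Older PyTorch versions do not have this function.
    setter = getattr(torch.backends.cuda, "preferred_rocm_fa_library", None)
    if setter is None:
        return False

    try:
        setter(name)
    except Exception:
        # Assumption: any failure here means the preference was not applied.
        return False
    return True


if __name__ == "__main__":
    print("CK preference accepted:", prefer_rocm_fa_backend("ck"))
```

Calling the helper with "aotriton" or "default" requests the incumbent implementation instead. On a build made without `USE_CK_FLASH_ATTENTION=1`, the CK path is not expected to be available even if the call itself succeeds.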

Files located in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention, courtesy of the hard work of @tridao, who is the co-author.

NOTE: In order to use this backend, the user MUST set USE_CK_FLASH_ATTENTION=1 in their environment when they build PyTorch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138947
Approved by: https://github.com/pruthvistony, https://github.com/xw285cornell, https://github.com/leitian

Co-authored-by: Xiaodong Wang <xw285@cornell.edu>
2024-12-17 02:18:07 +00:00
_static Fix search icon (#142808) 2024-12-12 16:09:30 +00:00
_templates Add an option for classic search (#142018) 2024-12-06 01:24:52 +00:00
community Update maintainers for inductor and x86 CPU (#136839) 2024-10-11 07:24:07 +00:00
elastic
notes Fix typo in Reproducibility docs (#141341) 2024-11-26 16:53:26 +00:00
rpc
scripts
accelerator.rst [BE][accelerator] formalize API name {current,set}_device_{idx => index} (#140542) 2024-12-12 10:53:48 +00:00
amp.rst Update document for autocast on CPU (#135299) 2024-09-13 09:11:47 +00:00
autograd.rst
backends.rst [ROCm] CK Flash Attention Backend (#138947) 2024-12-17 02:18:07 +00:00
benchmark_utils.rst
bottleneck.rst
checkpoint.rst
complex_numbers.rst
cond.rst
conf.py debug handler maintain through decomposition (#141612) 2024-12-12 12:26:45 +00:00
config_mod.rst
cpp_extension.rst
cpp_index.rst
cpu.rst
cuda._sanitizer.rst
cuda.rst Add API query for available per-process CUDA memory (#140620) 2024-12-03 00:24:03 +00:00
cuda.tunable.rst [ROCm] Fix TunableOp UTs: Rotating Buffer (#143172) 2024-12-14 06:18:11 +00:00
cuda_environment_variables.rst
cudnn_persistent_rnn.rst
cudnn_rnn_determinism.rst
data.rst
ddp_comm_hooks.rst
debugging_environment_variables.rst
deploy.rst
deterministic.rst
distributed.algorithms.join.rst
distributed.checkpoint.rst [DCP] Cross-link DCP doc to tutorials (#139776) 2024-11-07 02:19:49 +00:00
distributed.elastic.rst
distributed.fsdp.fully_shard.rst [FSDP2] Move to public torch.distributed.fsdp (#141868) 2024-12-07 01:24:28 +00:00
distributed.optim.rst
distributed.pipelining.rst [pipelining] Update tutorials and documentation (#143045) 2024-12-12 18:42:17 +00:00
distributed.rst [C10D] Update docs for wait() (#143305) 2024-12-17 00:41:11 +00:00
distributed.tensor.parallel.rst Update link in distributed.tensor.parallel.rst (#136103) 2024-09-15 19:36:29 +00:00
distributed.tensor.rst [dtensor][experimental] expose DTensor Context Parallel API (#137038) 2024-10-02 18:00:23 +00:00
distributions.rst
dlpack.rst
docutils.conf
export.ir_spec.rst [export] Update docs (#142011) 2024-12-05 03:44:46 +00:00
export.rst Add truediv support in export serializer (#136364) 2024-12-05 17:33:33 +00:00
fft.rst
fsdp.rst
func.api.rst
func.batch_norm.rst
func.migrating.rst
func.rst
func.ux_limitations.rst
func.whirlwind_tour.rst
future_mod.rst
futures.rst
fx.experimental.rst Add truediv support in export serializer (#136364) 2024-12-05 17:33:33 +00:00
fx.rst
hub.rst
index.rst Refactor "torch.mtia.memory_stats" API (#141723) 2024-12-09 19:19:19 +00:00
jit.rst
jit_builtin_functions.rst
jit_language_reference.rst
jit_language_reference_v2.rst
jit_python_reference.rst
jit_unsupported.rst
jit_utils.rst
library.rst Make torch.library.triton_op public (#141880) 2024-12-03 16:28:56 +00:00
linalg.rst
logging.rst
masked.rst
math-quantizer-equation.png
meta.rst
miscellaneous_environment_variables.rst Add environment variable to force no weights_only load (#138225) 2024-10-21 23:26:15 +00:00
mobile_optimizer.rst
model_zoo.rst
module_tracker.rst
monitor.rst
mps.rst
mps_environment_variables.rst
mtia.memory.rst Refactor "torch.mtia.memory_stats" API (#141723) 2024-12-09 19:19:19 +00:00
mtia.rst [MTIA] Support torch.mtia.empty_cache() (#141533) 2024-11-28 02:24:19 +00:00
multiprocessing.rst
name_inference.rst
named_tensor.rst
nested.rst
nn.attention.bias.rst
nn.attention.experimental.rst [Flex Attention] Paged Attention (#137164) 2024-10-29 17:05:22 +00:00
nn.attention.flex_attention.rst FlexAttention support for NJT (#136792) 2024-10-28 20:01:27 +00:00
nn.attention.rst [Flex Attention] Paged Attention (#137164) 2024-10-29 17:05:22 +00:00
nn.functional.rst
nn.init.rst
nn.rst Add APIs to separate norm calculation and gradient scaling in nn.utils.clip_grad_norm_ (#139662) 2024-11-07 23:13:23 +00:00
onnx.rst
onnx_dynamo.rst [ONNX] Describe memory usage of TorchDynamo-based exporter. (#139388) 2024-11-06 17:29:11 +00:00
onnx_dynamo_memory_usage.rst [ONNX] Describe memory usage of TorchDynamo-based exporter. (#139388) 2024-11-06 17:29:11 +00:00
onnx_dynamo_onnxruntime_backend.rst
onnx_torchscript.rst [ONNX] Remove deprecated export_to_pretty_string (#137790) 2024-10-21 18:17:48 +00:00
onnx_torchscript_supported_aten_ops.rst
optim.rst Ensure SWA boundary conditions w.r.t. definition (#133773) 2024-10-31 18:24:08 +00:00
package.rst
profiler.rst
quantization-accuracy-debugging.rst
quantization-backend-configuration.rst
quantization-support.rst
quantization.rst [BC-Breaking]Remove capture_pre_autograd_graph references in quantization (#139505) 2024-12-13 22:26:22 +00:00
random.rst
rpc.rst
signal.rst
size.rst
sparse.rst
special.rst
storage.rst Doc: Rewrite the storage.rst file to emphasize untyped storages (#140145) 2024-11-13 17:40:16 +00:00
tensor_attributes.rst [Docs] Remove duplicate declaration of double_tensor (#140927) 2024-11-18 21:22:30 +00:00
tensor_view.rst
tensorboard.rst
tensors.rst
testing.rst
threading_environment_variables.rst
torch.ao.ns._numeric_suite.rst
torch.ao.ns._numeric_suite_fx.rst
torch.compiler.config.rst Profile guided optimization for automatic_dynamic (#139001) 2024-11-03 06:29:57 +00:00
torch.compiler.rst Profile guided optimization for automatic_dynamic (#139001) 2024-11-03 06:29:57 +00:00
torch.compiler_aot_inductor.rst [aoti] Remove example inputs from aoti_compile_and_package (#140991) 2024-11-20 02:49:47 +00:00
torch.compiler_aot_inductor_minifier.rst Aoti minifier flatten (#141156) 2024-12-06 07:12:45 +00:00
torch.compiler_api.rst [dynamo] add torch.compiler.set_stance (#137504) 2024-10-16 16:18:25 +00:00
torch.compiler_best_practices_for_backends.rst
torch.compiler_cudagraph_trees.rst
torch.compiler_custom_backends.rst [pt2, docs] Add new PT2 troubleshooting doc (#138620) 2024-11-09 01:17:39 +00:00
torch.compiler_dynamic_shapes.rst
torch.compiler_dynamo_deepdive.rst fix typo in torch.compiler_dynamo_deepdive.rst (#140871) 2024-11-19 14:42:36 +00:00
torch.compiler_dynamo_overview.rst
torch.compiler_fake_tensor.rst [doc] improve code in fake tensor doc (#140329) 2024-11-13 05:14:56 +00:00
torch.compiler_faq.rst [pt2, docs] Add new PT2 troubleshooting doc (#138620) 2024-11-09 01:17:39 +00:00
torch.compiler_fine_grain_apis.rst
torch.compiler_get_started.rst [Inductor] Update AttrsDescriptor instantiation for Triton changes (#137458) 2024-10-14 20:20:29 +00:00
torch.compiler_inductor_profiling.rst
torch.compiler_ir.rst
torch.compiler_nn_module.rst
torch.compiler_performance_dashboard.rst
torch.compiler_profiling_torch_compile.rst [EZ] Fix spelling typo (#136157) 2024-09-16 19:30:30 +00:00
torch.compiler_transformations.rst
torch.compiler_troubleshooting.rst [pt2, docs] Add new PT2 troubleshooting doc (#138620) 2024-11-09 01:17:39 +00:00
torch.compiler_troubleshooting_old.rst [pt2, docs] Add new PT2 troubleshooting doc (#138620) 2024-11-09 01:17:39 +00:00
torch.overrides.rst
torch.rst Transform unbacked int expressions into a fresh unbacked int. (#141917) 2024-12-05 16:53:44 +00:00
torch_cuda_memory.rst
torch_environment_variables.rst
torch_nccl_environment_variables.rst
type_info.rst
utils.rst
xpu.rst Support torch.xpu.mem_get_info API (#141230) 2024-12-05 08:17:25 +00:00