mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Add option to configure reduced precision math backend for SDPA (#135964)
Summary: Address https://github.com/pytorch/pytorch/issues/135778 by adding a global flag that configures whether the math backend of SDPA uses high-precision or low-precision accumulation. Test Plan: buck2 run mode/opt //scripts/feikou/llm:run_attn_kernels Differential Revision: D62625515 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135964 Approved by: https://github.com/jbschlosser
This commit is contained in:
parent
44c871c34b
commit
0a35986cdb
10 changed files with 126 additions and 6 deletions
|
|
@ -145,6 +145,14 @@ void Context::setSDPUseMath(bool e) {
|
|||
enabled_mathSDP = e;
|
||||
}
|
||||
|
||||
// Whether the SDPA math backend is allowed to keep FP16/BF16 inputs in their
// reduced precision for its reductions instead of upcasting them to FP32
// (defaults to false, i.e. high-precision accumulation).
bool Context::allowFP16BF16ReductionMathSDP() const {
  return allow_fp16_bf16_reduction_mathSDP;
}
|
||||
|
||||
// Set whether the SDPA math backend may run its reductions in FP16/BF16 (true)
// or must upcast inputs to FP32 first (false, the default). Exposed to Python
// as torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp.
void Context::setAllowFP16BF16ReductionMathSDP(bool e) {
  allow_fp16_bf16_reduction_mathSDP = e;
}
|
||||
|
||||
// Whether the user has enabled the cuDNN backend of scaled dot product
// attention (see setSDPUseCuDNN).
bool Context::userEnabledCuDNNSDP() const {
  return enabled_cudnnSDP;
}
|
||||
|
|
|
|||
|
|
@ -234,6 +234,9 @@ class TORCH_API Context {
|
|||
void setSDPUseCuDNN(bool);
|
||||
bool userEnabledCuDNNSDP() const;
|
||||
|
||||
void setAllowFP16BF16ReductionMathSDP(bool);
|
||||
bool allowFP16BF16ReductionMathSDP() const;
|
||||
|
||||
void setSDPUseOverrideable(bool);
|
||||
bool userEnabledOverrideableSDP() const;
|
||||
|
||||
|
|
@ -390,6 +393,7 @@ class TORCH_API Context {
|
|||
bool enabled_mathSDP = true;
|
||||
bool enabled_cudnnSDP = true;
|
||||
bool enabled_overrideable = true;
|
||||
bool allow_fp16_bf16_reduction_mathSDP = false;
|
||||
#ifdef USE_ROCM
|
||||
bool benchmark_cudnn = true;
|
||||
#else
|
||||
|
|
|
|||
|
|
@ -804,22 +804,26 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
|
|||
value.is_contiguous(),
|
||||
"scaled_dot_product_attention: If inputs are nested tensors they must be contiguous");
|
||||
}
|
||||
auto& ctx = at::globalContext();
|
||||
auto origin_dtype = query_.scalar_type();
|
||||
// Keep query, key, value in high precision for accuracy
|
||||
// NestedTensor reports issues for backward with autograd so disabled: must be
|
||||
// contiguous to get buffer.
|
||||
auto query_acc = (query_.scalar_type() == at::kHalf ||
|
||||
query_.scalar_type() == at::kBFloat16) &&
|
||||
auto query_acc = !ctx.allowFP16BF16ReductionMathSDP() &&
|
||||
(query_.scalar_type() == at::kHalf ||
|
||||
query_.scalar_type() == at::kBFloat16) &&
|
||||
!query_.is_nested()
|
||||
? query_.to(at::kFloat)
|
||||
: query_;
|
||||
auto key_acc =
|
||||
(key.scalar_type() == at::kHalf || key.scalar_type() == at::kBFloat16) &&
|
||||
auto key_acc = !ctx.allowFP16BF16ReductionMathSDP() &&
|
||||
(key.scalar_type() == at::kHalf ||
|
||||
key.scalar_type() == at::kBFloat16) &&
|
||||
!key.is_nested()
|
||||
? key.to(at::kFloat)
|
||||
: key;
|
||||
auto value_acc = (value.scalar_type() == at::kHalf ||
|
||||
value.scalar_type() == at::kBFloat16) &&
|
||||
auto value_acc = !ctx.allowFP16BF16ReductionMathSDP() &&
|
||||
(value.scalar_type() == at::kHalf ||
|
||||
value.scalar_type() == at::kBFloat16) &&
|
||||
!value.is_nested()
|
||||
? value.to(at::kFloat)
|
||||
: value;
|
||||
|
|
|
|||
|
|
@ -85,6 +85,10 @@ torch.backends.cuda
|
|||
|
||||
.. autofunction:: torch.backends.cuda.enable_math_sdp
|
||||
|
||||
.. autofunction:: torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed
|
||||
|
||||
.. autofunction:: torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp
|
||||
|
||||
.. autofunction:: torch.backends.cuda.cudnn_sdp_enabled
|
||||
|
||||
.. autofunction:: torch.backends.cuda.enable_cudnn_sdp
|
||||
|
|
|
|||
|
|
@ -110,6 +110,13 @@ reduced-precision reductions are problematic, they can be turned off with
|
|||
|
||||
For more information see :ref:`allow_fp16_reduced_precision_reduction<fp16reducedprecision>` and :ref:`allow_bf16_reduced_precision_reduction<bf16reducedprecision>`
|
||||
|
||||
Reduced Precision Reduction for FP16 and BF16 in Scaled Dot Product Attention (SDPA)
|
||||
------------------------------------------------------------------------------------
|
||||
A naive SDPA math backend, when using FP16/BF16 inputs, can accumulate significant numerical errors due to the usage of low-precision intermediate buffers. To mitigate this issue, the default behavior now involves upcasting FP16/BF16 inputs to FP32. Computations are performed in FP32/TF32, and the final FP32 results are then downcast back to FP16/BF16. This improves the numerical accuracy of the final output for the math backend with FP16/BF16 inputs, but increases memory usage and may cause performance regressions in the math backend as computations shift from FP16/BF16 BMM to FP32/TF32 BMM/Matmul.
|
||||
|
||||
For scenarios where reduced-precision reductions are preferred for speed, they can be enabled with the following setting:
|
||||
``torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)``
|
||||
|
||||
.. _fp16_on_mi200:
|
||||
|
||||
Reduced Precision FP16 and BF16 GEMMs and Convolutions on AMD Instinct MI200 devices
|
||||
|
|
|
|||
|
|
@ -849,6 +849,44 @@ class TestTransformers(NNTestCase):
|
|||
|
||||
self.assertEqual(masked_output, is_causal_output)
|
||||
|
||||
@onlyCUDA
@unittest.skipIf(
    not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support FlashAttention (pre-SM80 hardware)"
)
def test_math_backend_high_precision(self):
    """The SDPA math backend upcasts FP16/BF16 inputs to FP32 by default;
    ``allow_fp16_bf16_reduction_math_sdp(True)`` opts back into the faster
    but less accurate reduced-precision path.

    The high-precision output must match an FP64 reference within tolerance,
    while the reduced-precision output must not.
    """
    # rand() * 5 inflates the magnitudes so that low- vs. high-precision
    # accumulation error is clearly distinguishable in the assertions below.
    xq = torch.rand([1, 128, 2, 80], device="cuda", dtype=torch.bfloat16) * 5
    xk = torch.rand([1, 128, 2, 80], device="cuda", dtype=torch.bfloat16) * 5
    xv = torch.randn([1, 128, 2, 80], device="cuda", dtype=torch.bfloat16)
    mask = None

    def scaled_dot_product_attention(
        xq: torch.Tensor,
        xk: torch.Tensor,
        xv: torch.Tensor,
        mask: Optional[torch.Tensor],
        backend: SDPBackend,
    ) -> torch.Tensor:
        # (batch, seq, heads, dim) -> (batch, heads, seq, dim) for SDPA.
        n_rep = 1
        xq, xk, xv = (tensor.transpose(1, 2) for tensor in (xq, xk, xv))
        xk = xk.repeat_interleave(n_rep, dim=1)
        xv = xv.repeat_interleave(n_rep, dim=1)

        with sdpa_kernel(backends=[backend]):
            attn_output = F.scaled_dot_product_attention(
                xq, xk, xv, attn_mask=mask, dropout_p=0.0
            )
        return attn_output.transpose(1, 2)

    try:
        torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
        sdp_math_low_prec_out = scaled_dot_product_attention(
            xq, xk, xv, mask, SDPBackend.MATH
        )
        torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(False)
        sdp_math_high_prec_out = scaled_dot_product_attention(
            xq, xk, xv, mask, SDPBackend.MATH
        )
    finally:
        # Restore the default even if an SDPA call raises, so the global
        # flag never leaks into other tests.
        torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(False)

    sdp_math_fp64_out_ref = scaled_dot_product_attention(
        xq.double(), xk.double(), xv.double(), mask, SDPBackend.MATH
    ).bfloat16()

    torch.testing.assert_close(
        sdp_math_high_prec_out, sdp_math_fp64_out_ref, atol=1e-2, rtol=1e-2
    )

    with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close"):
        torch.testing.assert_close(
            sdp_math_low_prec_out, sdp_math_fp64_out_ref, atol=1e-2, rtol=1e-2
        )
|
||||
|
||||
@onlyCUDA
|
||||
@parametrize("nb_heads", [1, 8])
|
||||
@parametrize("bias", [True, False])
|
||||
|
|
|
|||
|
|
@ -1156,6 +1156,8 @@ def _set_sdp_use_mem_efficient(
|
|||
) -> None: ... # THPModule_setSDPUseMemEfficient
|
||||
def _get_math_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
|
||||
def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath
|
||||
def _get_math_sdp_allow_fp16_bf16_reduction() -> _bool: ... # THPModule_allowFP16BF16ReductionMathSDP
|
||||
def _set_math_sdp_allow_fp16_bf16_reduction(arg: _bool) -> None: ... # THPModule_setAllowFP16BF16ReductionMathSDP
|
||||
def _get_overrideable_sdp_enabled() -> _bool: ... # THPModule_userEnabledOverrideableSDP
|
||||
def _set_sdp_use_overrideable(arg: _bool) -> None: ... # THPModule_setSDPUseOverrideable
|
||||
def _get_cudnn_sdp_enabled() -> _bool: ...  # THPModule_userEnabledCuDNNSDP
|
||||
|
|
|
|||
|
|
@ -610,6 +610,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
|
|||
"torch._C._get_graph_executor_optimize",
|
||||
"torch._C._get_linalg_preferred_backend",
|
||||
"torch._C._get_math_sdp_enabled",
|
||||
"torch._C._get_math_sdp_allow_fp16_bf16_reduction",
|
||||
"torch._C._get_max_operator_version",
|
||||
"torch._C._get_mem_efficient_sdp_enabled",
|
||||
"torch._C._get_mkldnn_enabled",
|
||||
|
|
@ -1145,6 +1146,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
|
|||
"torch._C._set_qengine",
|
||||
"torch._C._set_sdp_use_flash",
|
||||
"torch._C._set_sdp_use_math",
|
||||
"torch._C._set_math_sdp_allow_fp16_bf16_reduction",
|
||||
"torch._C._set_sdp_use_mem_efficient",
|
||||
"torch._C._set_should_use_format_with_string_table",
|
||||
"torch._C._set_storage_access_error_msg",
|
||||
|
|
@ -2398,11 +2400,13 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
|
|||
"torch.backends.cuda.can_use_cudnn_attention",
|
||||
"torch.backends.cuda.enable_flash_sdp",
|
||||
"torch.backends.cuda.enable_math_sdp",
|
||||
"torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp",
|
||||
"torch.backends.cuda.enable_mem_efficient_sdp",
|
||||
"torch.backends.cuda.flash_sdp_enabled",
|
||||
"torch.backends.cuda.is_built",
|
||||
"torch.backends.cuda.is_flash_attention_available",
|
||||
"torch.backends.cuda.math_sdp_enabled",
|
||||
"torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed",
|
||||
"torch.backends.cuda.mem_efficient_sdp_enabled",
|
||||
"torch.backends.cuda.cudnn_sdp_enabled",
|
||||
"torch.backends.cuda.enable_cudnn_sdp",
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ __all__ = [
|
|||
"mem_efficient_sdp_enabled",
|
||||
"math_sdp_enabled",
|
||||
"enable_math_sdp",
|
||||
"allow_fp16_bf16_reduction_math_sdp",
|
||||
"fp16_bf16_reduction_math_sdp_allowed",
|
||||
"is_flash_attention_available",
|
||||
"can_use_flash_attention",
|
||||
"can_use_efficient_attention",
|
||||
|
|
@ -322,6 +324,24 @@ def enable_math_sdp(enabled: bool):
|
|||
torch._C._set_sdp_use_math(enabled)
|
||||
|
||||
|
||||
def allow_fp16_bf16_reduction_math_sdp(enabled: bool):
    r"""
    .. warning:: This flag is beta and subject to change.

    Enables or disables fp16/bf16 reduction in math scaled dot product attention.

    Args:
        enabled: If ``True``, the math backend may keep FP16/BF16 inputs in
            reduced precision (faster, less accurate). If ``False`` (the
            default), FP16/BF16 inputs are upcast to FP32 for accumulation.
    """
    torch._C._set_math_sdp_allow_fp16_bf16_reduction(enabled)
|
||||
|
||||
|
||||
def fp16_bf16_reduction_math_sdp_allowed():
    r"""
    .. warning:: This flag is beta and subject to change.

    Returns whether fp16/bf16 reduction in math scaled dot product attention is enabled or not.

    Returns:
        bool: the current value of the global flag set by
        :func:`allow_fp16_bf16_reduction_math_sdp` (``False`` by default).
    """
    return torch._C._get_math_sdp_allow_fp16_bf16_reduction()
|
||||
|
||||
|
||||
def is_flash_attention_available() -> bool:
|
||||
r"""Check if PyTorch was built with FlashAttention for scaled_dot_product_attention.
|
||||
|
||||
|
|
|
|||
|
|
@ -738,6 +738,27 @@ PyObject* THPModule_userEnabledMathSDP(PyObject* _unused, PyObject* noargs) {
|
|||
else
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
// Python binding for Context::setAllowFP16BF16ReductionMathSDP; exposed as
// torch._C._set_math_sdp_allow_fp16_bf16_reduction. Expects a single bool.
PyObject* THPModule_setAllowFP16BF16ReductionMathSDP(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  // Fixed copy-pasted message: it previously claimed to be set_sdp_use_math,
  // which would mislead users passing a non-bool to this binding.
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_math_sdp_allow_fp16_bf16_reduction expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowFP16BF16ReductionMathSDP(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
||||
// Python binding for Context::allowFP16BF16ReductionMathSDP; exposed as
// torch._C._get_math_sdp_allow_fp16_bf16_reduction.
PyObject* THPModule_allowFP16BF16ReductionMathSDP(
    PyObject* _unused,
    PyObject* noargs) {
  const bool allowed = at::globalContext().allowFP16BF16ReductionMathSDP();
  if (allowed) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
|
||||
PyObject* THPModule_setSDPUseOverrideable(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK(
|
||||
|
|
@ -1362,6 +1383,14 @@ static PyMethodDef TorchMethods[] = { // NOLINT
|
|||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_set_sdp_use_math", THPModule_setSDPUseMath, METH_O, nullptr},
|
||||
{"_get_math_sdp_allow_fp16_bf16_reduction",
|
||||
THPModule_allowFP16BF16ReductionMathSDP,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"_set_math_sdp_allow_fp16_bf16_reduction",
|
||||
THPModule_setAllowFP16BF16ReductionMathSDP,
|
||||
METH_O,
|
||||
nullptr},
|
||||
{"_get_overrideable_sdp_enabled",
|
||||
THPModule_userEnabledOverrideableSDP,
|
||||
METH_NOARGS,
|
||||
|
|
|
|||
Loading…
Reference in a new issue