Add option to configure reduced precision math backend for SDPA (#135964)

Summary: Addresses https://github.com/pytorch/pytorch/issues/135778 by adding a global flag that configures whether the math backend of SDPA runs in high precision (upcasting FP16/BF16 inputs to FP32) or in reduced precision.
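As a usage sketch (illustrative only, not part of the diff; assumes a CUDA build), the flag added here is toggled and queried through the new torch.backends.cuda helpers:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# Default is False: the math backend upcasts FP16/BF16 inputs to FP32.
print(torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed())

# Opt in to reduced-precision reductions (faster, less accurate).
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
with sdpa_kernel([SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)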

Test Plan: buck2 run mode/opt //scripts/feikou/llm:run_attn_kernels

Differential Revision: D62625515

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135964
Approved by: https://github.com/jbschlosser
Jianyu Huang 2024-09-24 07:11:36 +00:00 committed by PyTorch MergeBot
parent 44c871c34b
commit 0a35986cdb
10 changed files with 126 additions and 6 deletions


@@ -145,6 +145,14 @@ void Context::setSDPUseMath(bool e) {
enabled_mathSDP = e;
}
bool Context::allowFP16BF16ReductionMathSDP() const {
return allow_fp16_bf16_reduction_mathSDP;
}
void Context::setAllowFP16BF16ReductionMathSDP(bool e) {
allow_fp16_bf16_reduction_mathSDP = e;
}
bool Context::userEnabledCuDNNSDP() const {
return enabled_cudnnSDP;
}


@@ -234,6 +234,9 @@ class TORCH_API Context {
void setSDPUseCuDNN(bool);
bool userEnabledCuDNNSDP() const;
void setAllowFP16BF16ReductionMathSDP(bool);
bool allowFP16BF16ReductionMathSDP() const;
void setSDPUseOverrideable(bool);
bool userEnabledOverrideableSDP() const;
@@ -390,6 +393,7 @@ class TORCH_API Context {
bool enabled_mathSDP = true;
bool enabled_cudnnSDP = true;
bool enabled_overrideable = true;
bool allow_fp16_bf16_reduction_mathSDP = false;
#ifdef USE_ROCM
bool benchmark_cudnn = true;
#else


@@ -804,22 +804,26 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
value.is_contiguous(),
"scaled_dot_product_attention: If inputs are nested tensors they must be contiguous");
}
+ auto& ctx = at::globalContext();
  auto origin_dtype = query_.scalar_type();
  // Keep query, key, value in high precision for accuracy
  // NestedTensor reports issues for backward with autograd so disabled: must be
  // contiguous to get buffer.
- auto query_acc = (query_.scalar_type() == at::kHalf ||
-     query_.scalar_type() == at::kBFloat16) &&
+ auto query_acc = !ctx.allowFP16BF16ReductionMathSDP() &&
+     (query_.scalar_type() == at::kHalf ||
+      query_.scalar_type() == at::kBFloat16) &&
      !query_.is_nested()
    ? query_.to(at::kFloat)
    : query_;
- auto key_acc =
-     (key.scalar_type() == at::kHalf || key.scalar_type() == at::kBFloat16) &&
+ auto key_acc = !ctx.allowFP16BF16ReductionMathSDP() &&
+     (key.scalar_type() == at::kHalf ||
+      key.scalar_type() == at::kBFloat16) &&
      !key.is_nested()
    ? key.to(at::kFloat)
    : key;
- auto value_acc = (value.scalar_type() == at::kHalf ||
-     value.scalar_type() == at::kBFloat16) &&
+ auto value_acc = !ctx.allowFP16BF16ReductionMathSDP() &&
+     (value.scalar_type() == at::kHalf ||
+      value.scalar_type() == at::kBFloat16) &&
      !value.is_nested()
    ? value.to(at::kFloat)
    : value;
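For orientation, a minimal Python sketch of what the upcast branches above compute for plain dense inputs (no mask, dropout, or custom scale; the helper name is made up for illustration and this mirrors the intent of the C++ code, not the actual kernel):

import torch

def math_sdpa_reference(q, k, v, allow_reduced_precision=False):
    # Unless reduced-precision reductions are allowed, FP16/BF16 inputs are
    # promoted to FP32 for the computation; only the final result is cast back.
    orig_dtype = q.dtype
    if not allow_reduced_precision and orig_dtype in (torch.float16, torch.bfloat16):
        q, k, v = q.float(), k.float(), v.float()
    scale = q.shape[-1] ** -0.5
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return (attn @ v).to(orig_dtype)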


@@ -85,6 +85,10 @@ torch.backends.cuda
.. autofunction:: torch.backends.cuda.enable_math_sdp
.. autofunction:: torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed
.. autofunction:: torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp
.. autofunction:: torch.backends.cuda.cudnn_sdp_enabled
.. autofunction:: torch.backends.cuda.enable_cudnn_sdp


@@ -110,6 +110,13 @@ reduced-precision reductions are problematic, they can be turned off with
For more information see :ref:`allow_fp16_reduced_precision_reduction<fp16reducedprecision>` and :ref:`allow_bf16_reduced_precision_reduction<bf16reducedprecision>`
Reduced Precision Reduction for FP16 and BF16 in Scaled Dot Product Attention (SDPA)
------------------------------------------------------------------------------------
A naive SDPA math backend, when using FP16/BF16 inputs, can accumulate significant numerical errors due to the use of low-precision intermediate buffers. To mitigate this issue, the default behavior is now to upcast FP16/BF16 inputs to FP32: the computation runs in FP32/TF32, and the final FP32 result is cast back down to FP16/BF16. This improves the numerical accuracy of the final output for the math backend with FP16/BF16 inputs, but it increases memory usage and may cause performance regressions in the math backend, as the computation shifts from FP16/BF16 BMM to FP32/TF32 BMM/matmul.
For scenarios where reduced-precision reductions are preferred for speed, they can be enabled with the following setting:
``torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)``
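As an illustrative pattern (not part of this diff or the documentation; the helper below is hypothetical), the global toggle can be scoped to a region so the rest of the program keeps the default high-precision behavior:

import contextlib
import torch

@contextlib.contextmanager
def fp16_bf16_reduction_math_sdp(enabled: bool):
    # Temporarily set the global flag, restoring the previous value on exit.
    prev = torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed()
    torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(enabled)
    try:
        yield
    finally:
        torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(prev)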
.. _fp16_on_mi200:
Reduced Precision FP16 and BF16 GEMMs and Convolutions on AMD Instinct MI200 devices


@@ -849,6 +849,44 @@ class TestTransformers(NNTestCase):
self.assertEqual(masked_output, is_causal_output)
@onlyCUDA
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support pre-SM80 hardware"
)
def test_math_backend_high_precision(self):
xq = torch.rand([1, 128, 2, 80], device="cuda", dtype=torch.bfloat16) * 5
xk = torch.rand([1, 128, 2, 80], device="cuda", dtype=torch.bfloat16) * 5
xv = torch.randn([1, 128, 2, 80], device="cuda", dtype=torch.bfloat16)
mask = None
def scaled_dot_product_attention(
xq: torch.Tensor, xk: torch.Tensor, xv: torch.Tensor, mask: Optional[torch.Tensor], backend: SDPBackend
) -> torch.Tensor:
n_rep = 1
xq, xk, xv = (tensor.transpose(1, 2) for tensor in (xq, xk, xv))
xk = xk.repeat_interleave(n_rep, dim=1)
xv = xv.repeat_interleave(n_rep, dim=1)
with sdpa_kernel(backends=[backend]):
attn_output = F.scaled_dot_product_attention(
xq, xk, xv, attn_mask=mask, dropout_p=0.0
)
return attn_output.transpose(1, 2)
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
sdp_math_low_prec_out = scaled_dot_product_attention(xq, xk, xv, mask, SDPBackend.MATH)
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(False)
sdp_math_high_prec_out = scaled_dot_product_attention(xq, xk, xv, mask, SDPBackend.MATH)
sdp_math_fp64_out_ref = scaled_dot_product_attention(
xq.double(), xk.double(), xv.double(), mask, SDPBackend.MATH
).bfloat16()
torch.testing.assert_close(sdp_math_high_prec_out, sdp_math_fp64_out_ref, atol=1e-2, rtol=1e-2)
with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close"):
torch.testing.assert_close(sdp_math_low_prec_out, sdp_math_fp64_out_ref, atol=1e-2, rtol=1e-2)
@onlyCUDA
@parametrize("nb_heads", [1, 8])
@parametrize("bias", [True, False])


@@ -1156,6 +1156,8 @@ def _set_sdp_use_mem_efficient(
) -> None: ... # THPModule_setSDPUseMemEfficient
def _get_math_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath
def _get_math_sdp_allow_fp16_bf16_reduction() -> _bool: ... # THPModule_allowFP16BF16ReductionMathSDP
def _set_math_sdp_allow_fp16_bf16_reduction(arg: _bool) -> None: ... # THPModule_setAllowFP16BF16ReductionMathSDP
def _get_overrideable_sdp_enabled() -> _bool: ... # THPModule_userEnabledOverrideableSDP
def _set_sdp_use_overrideable(arg: _bool) -> None: ... # THPModule_setSDPUseOverrideable
def _get_cudnn_sdp_enabled() -> _bool: ... # THPModule_userEnabledCuDNNSDP


@@ -610,6 +610,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._get_graph_executor_optimize",
"torch._C._get_linalg_preferred_backend",
"torch._C._get_math_sdp_enabled",
"torch._C._get_math_sdp_allow_fp16_bf16_reduction",
"torch._C._get_max_operator_version",
"torch._C._get_mem_efficient_sdp_enabled",
"torch._C._get_mkldnn_enabled",
@@ -1145,6 +1146,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._set_qengine",
"torch._C._set_sdp_use_flash",
"torch._C._set_sdp_use_math",
"torch._C._set_math_sdp_allow_fp16_bf16_reduction",
"torch._C._set_sdp_use_mem_efficient",
"torch._C._set_should_use_format_with_string_table",
"torch._C._set_storage_access_error_msg",
@@ -2398,11 +2400,13 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.backends.cuda.can_use_cudnn_attention",
"torch.backends.cuda.enable_flash_sdp",
"torch.backends.cuda.enable_math_sdp",
"torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp",
"torch.backends.cuda.enable_mem_efficient_sdp",
"torch.backends.cuda.flash_sdp_enabled",
"torch.backends.cuda.is_built",
"torch.backends.cuda.is_flash_attention_available",
"torch.backends.cuda.math_sdp_enabled",
"torch.backends.cuda.fp16_bf16_reduction_math_sdp_allowed",
"torch.backends.cuda.mem_efficient_sdp_enabled",
"torch.backends.cuda.cudnn_sdp_enabled",
"torch.backends.cuda.enable_cudnn_sdp",


@@ -25,6 +25,8 @@ __all__ = [
"mem_efficient_sdp_enabled",
"math_sdp_enabled",
"enable_math_sdp",
"allow_fp16_bf16_reduction_math_sdp",
"fp16_bf16_reduction_math_sdp_allowed",
"is_flash_attention_available",
"can_use_flash_attention",
"can_use_efficient_attention",
@@ -322,6 +324,24 @@ def enable_math_sdp(enabled: bool):
torch._C._set_sdp_use_math(enabled)
def allow_fp16_bf16_reduction_math_sdp(enabled: bool):
r"""
.. warning:: This flag is beta and subject to change.
Enables or disables fp16/bf16 reduction in math scaled dot product attention.
"""
torch._C._set_math_sdp_allow_fp16_bf16_reduction(enabled)
def fp16_bf16_reduction_math_sdp_allowed():
r"""
.. warning:: This flag is beta and subject to change.
Returns whether fp16/bf16 reduction in math scaled dot product attention is enabled or not.
"""
return torch._C._get_math_sdp_allow_fp16_bf16_reduction()
def is_flash_attention_available() -> bool:
r"""Check if PyTorch was built with FlashAttention for scaled_dot_product_attention.


@@ -738,6 +738,27 @@ PyObject* THPModule_userEnabledMathSDP(PyObject* _unused, PyObject* noargs) {
else
Py_RETURN_FALSE;
}
PyObject* THPModule_setAllowFP16BF16ReductionMathSDP(
PyObject* _unused,
PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(
PyBool_Check(arg),
"set_sdp_use_math expects a bool, "
"but got ",
THPUtils_typename(arg));
at::globalContext().setAllowFP16BF16ReductionMathSDP(arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
PyObject* THPModule_allowFP16BF16ReductionMathSDP(
PyObject* _unused,
PyObject* noargs) {
if (at::globalContext().allowFP16BF16ReductionMathSDP())
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseOverrideable(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(
@@ -1362,6 +1383,14 @@ static PyMethodDef TorchMethods[] = { // NOLINT
METH_NOARGS,
nullptr},
{"_set_sdp_use_math", THPModule_setSDPUseMath, METH_O, nullptr},
{"_get_math_sdp_allow_fp16_bf16_reduction",
THPModule_allowFP16BF16ReductionMathSDP,
METH_NOARGS,
nullptr},
{"_set_math_sdp_allow_fp16_bf16_reduction",
THPModule_setAllowFP16BF16ReductionMathSDP,
METH_O,
nullptr},
{"_get_overrideable_sdp_enabled",
THPModule_userEnabledOverrideableSDP,
METH_NOARGS,