From c7cfa5172139737bf75afbd4a7920b1a02b1dcb2 Mon Sep 17 00:00:00 2001
From: Jianyu Huang
Date: Sun, 4 Aug 2024 23:58:14 +0000
Subject: [PATCH] Always use high precision for SDPA math backend (#128922)

Summary:
feikou observed big numerical gaps when using the math backend on AMD and NV GPUs. They occur mainly because we were not using higher-precision FP32 for the intermediate accumulated/materialized tensors.

Since the math backend is expected to be slower anyway, and we expect it to produce the correct reference result, it is worth upcasting FP16/BF16 inputs to FP32, performing the computation in FP32/TF32, and then downcasting the FP32 output back to FP16/BF16.

Differential Revision: D58710805

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922
Approved by: https://github.com/xw285cornell, https://github.com/drisspg
---
 .../ATen/native/transformers/attention.cpp | 72 +++++++++++++------
 test/test_decomp.py                        |  3 +-
 test/test_transformers.py                  | 20 +++---
 torch/nn/functional.py                     |  1 +
 4 files changed, 61 insertions(+), 35 deletions(-)

diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index f55683ecb26..576b92b0e56 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -73,8 +73,7 @@
 #endif
 #include
 
-namespace at {
-namespace native {
+namespace at::native {
 
 DEFINE_DISPATCH(_fused_sdp_choice_stub);
 
@@ -734,26 +733,55 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
         value.is_contiguous(),
         "scaled_dot_product_attention: If inputs are nested tensors they must be contiguous");
   }
-  auto attn_mask = attn_mask_;
-  // Naive, composite implementation defined here.
+  auto origin_dtype = query_.scalar_type();
+  // Keep query, key, value in high precision for accuracy
+  // NestedTensor reports issues for backward with autograd so disabled: must be
+  // contiguous to get buffer.
+  auto query_acc = (query_.scalar_type() == at::kHalf ||
+                    query_.scalar_type() == at::kBFloat16) &&
+          !query_.is_nested()
+      ? query_.to(at::kFloat)
+      : query_;
+  auto key_acc =
+      (key.scalar_type() == at::kHalf || key.scalar_type() == at::kBFloat16) &&
+          !key.is_nested()
+      ? key.to(at::kFloat)
+      : key;
+  auto value_acc = (value.scalar_type() == at::kHalf ||
+                    value.scalar_type() == at::kBFloat16) &&
+          !value.is_nested()
+      ? value.to(at::kFloat)
+      : value;
+  auto attn_mask = attn_mask_;
+  // Naive, composite implementation defined here.
 
-  // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
-  bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
-  const auto scaling_factor = sdp::calculate_scale(query_, is_negative_scaling ? std::abs(scale.value()) : scale).sqrt();
+  // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for
+  // math
+  bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
+  const auto scaling_factor =
+      sdp::calculate_scale(
+          query_acc, is_negative_scaling ? std::abs(scale.value()) : scale)
+          .sqrt();
 
-  const auto query = query_ * (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor: scaling_factor);
-  if (is_causal) {
-    TORCH_CHECK(!attn_mask.has_value(),
-        "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
-    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
-        "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");
+  const auto query = query_acc *
+      (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor
+                           : scaling_factor);
+  if (is_causal) {
+    TORCH_CHECK(
+        !attn_mask.has_value(),
+        "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
+    TORCH_CHECK(
+        !query.is_nested() && !key_acc.is_nested(),
+        "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");
 
-    // Replace attn_mask with causal mask; lower triangular elements take part in attention.
-    const auto L = query.sym_size(-2), S = key.sym_size(-2);
-    attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
-    attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
+    // Replace attn_mask with causal mask; lower triangular elements take part
+    // in attention.
+    const auto L = query.sym_size(-2), S = key_acc.sym_size(-2);
+    attn_mask =
+        at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
+    attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
   }
-  auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor);
+  auto attn = at::matmul(query, key_acc.transpose(-2, -1) * scaling_factor);
   if (attn_mask.has_value()) {
     if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
       attn = attn.add(*attn_mask);
@@ -769,13 +797,13 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
       TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
       attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
       auto dropout_scaling = 1.0 / (1 - dropout_p);
-      return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn);
+      return std::make_tuple(at::matmul(attn, value_acc * dropout_scaling).to(origin_dtype), attn.to(origin_dtype));
     } else {
       attn = at::dropout(attn, dropout_p, true);
     }
   }
 
-  return std::make_tuple(at::matmul(attn, value), attn);
+  return std::make_tuple(at::matmul(attn, value_acc).to(origin_dtype), attn.to(origin_dtype));
 }
 
 std::tuple
@@ -998,6 +1026,4 @@ Tensor triton_multi_head_attention(
 #endif
   return proj;
 }
-
-} // namespace native
-} // namespace at
+} // namespace at::native
diff --git a/test/test_decomp.py b/test/test_decomp.py
index c76ad2e2f42..3c90c0f16b7 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -31,7 +31,6 @@ from torch.testing._internal.common_methods_invocations import (
 from torch.testing._internal.common_modules import module_db, modules
 from torch.testing._internal.common_utils import (
     is_iterable_of_tensors,
-    IS_MACOS,
     run_tests,
     skipIfCrossRef,
     skipIfTorchDynamo,
@@ -1171,7 +1170,7 @@ class DecompOneOffTests(TestCase):
         [
             xfail(
                 "nn.functional.scaled_dot_product_attention",
-                dtypes=[torch.half] + ([torch.bfloat16] if IS_MACOS else []),
+                dtypes=[torch.half],
             ),
         ],
     )
diff --git a/test/test_transformers.py b/test/test_transformers.py
index af457828ebe..b3b1e1419af 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -2834,8 +2834,8 @@ class TestSDPACudaOnly(NNTestCase):
             (out_ref, out_lp_ref, out),
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
-                'out': 2.0 ,
-                'grad_query': 18.0 ,
+                'out': 3.0 ,
+                'grad_query': 150.0 ,
                 'grad_key': 25.0,
                 'grad_value': 8.5,
             }
@@ -2931,8 +2931,8 @@ class TestSDPACudaOnly(NNTestCase):
             (out_ref, out_lp_ref, out),
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
-                "out": 1.75,
-                "grad_query": 18.0,
+                "out": 4,
+                "grad_query": 150.0,
                 "grad_key": 25.0,
                 "grad_value": 8.0,
                 "grad_attn_mask": 45.0,
@@ -3030,10 +3030,10 @@ class TestSDPACudaOnly(NNTestCase):
             (out_ref, out_lp_ref, out),
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
-                'out': 1.5,
-                'grad_query': 13.0,
-                'grad_key': 2.0,
-                'grad_value': 1.5,
+                'out': 2.2,
+                'grad_query': 160.0,
+                'grad_key': 8.0,
+                'grad_value': 4,
             }
         )
 
@@ -3179,8 +3179,8 @@ class TestSDPACudaOnly(NNTestCase):
             *zip(grads_ref, grads_ref_lp, grads),
             fudge_factors={
                 'out': 2.0,
-                'grad_query': 12.0,
-                'grad_key': 2.0,
+                'grad_query': 100.0,
+                'grad_key': 8.0,
                 'grad_value': 2.0,
             }
         )
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index af8710e00a0..67c3bfc85de 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -5686,6 +5686,7 @@
 Note: Due to the nature of fusing floating point operations, the output of this function may be
     different depending on what backend kernel is chosen.
     The c++ implementation supports torch.float64 and can be used when higher precision is required.
+    For math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16.
     For more information please see :doc:`/notes/numerical_accuracy`
 
 Note:
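
Reviewer note, not part of the patch: the snippet below is a minimal sketch of
the behavior this change guarantees. It assumes PyTorch 2.3+, where
torch.nn.attention.sdpa_kernel is available for backend selection; the tensor
shapes, device fallback, and tolerances are illustrative choices, not values
taken from the PR or its tests.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda" if torch.cuda.is_available() else "cpu"
# Low-precision inputs; shape is (batch, heads, seq_len, head_dim), chosen arbitrarily.
q, k, v = (
    torch.randn(2, 4, 64, 32, dtype=torch.float16, device=device)
    for _ in range(3)
)

# Force the math (reference) backend. With this patch, FP16/BF16 inputs are
# upcast to FP32 internally and only the final output is cast back down.
with sdpa_kernel(SDPBackend.MATH):
    out_lp = F.scaled_dot_product_attention(q, k, v)

# The same computation with manually upcast inputs, as an explicit FP32 reference.
with sdpa_kernel(SDPBackend.MATH):
    out_ref = F.scaled_dot_product_attention(q.float(), k.float(), v.float())

# The low-precision output should now be the FP32 result rounded once at the
# end, so the gap should be on the order of FP16 rounding error.
torch.testing.assert_close(out_lp, out_ref.to(torch.float16), atol=2e-3, rtol=2e-3)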