mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Always use high precision for SDPA math backend (#128922)
Summary: feikou observed the big numerical gaps when using math backend on AMD and NV GPUs. It's mainly because we are not using higher precision FP32 for the intermediate accumulated/materialized parts. Since math backend is expected to be slower anyways, and we expect math backend to generate the correct reference result, I think it should be worth to upcast FP16/BF16 input to FP32, and do FP32/TF32 computations, and then downcast FP32 output back to FP16/BF16. Differential Revision: D58710805 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128922 Approved by: https://github.com/xw285cornell, https://github.com/drisspg
This commit is contained in:
parent
01cdcbf7c8
commit
c7cfa51721
4 changed files with 61 additions and 35 deletions
|
|
@ -73,8 +73,7 @@
|
|||
#endif
|
||||
|
||||
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace at::native {
|
||||
|
||||
DEFINE_DISPATCH(_fused_sdp_choice_stub);
|
||||
|
||||
|
|
@ -734,26 +733,55 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
|
|||
value.is_contiguous(),
|
||||
"scaled_dot_product_attention: If inputs are nested tensors they must be contiguous");
|
||||
}
|
||||
auto attn_mask = attn_mask_;
|
||||
// Naive, composite implementation defined here.
|
||||
auto origin_dtype = query_.scalar_type();
|
||||
// Keep query, key, value in high precision for accuracy
|
||||
// NestedTensor reports issues for backward with autograd so disabled: must be
|
||||
// contiguous to get buffer.
|
||||
auto query_acc = (query_.scalar_type() == at::kHalf ||
|
||||
query_.scalar_type() == at::kBFloat16) &&
|
||||
!query_.is_nested()
|
||||
? query_.to(at::kFloat)
|
||||
: query_;
|
||||
auto key_acc =
|
||||
(key.scalar_type() == at::kHalf || key.scalar_type() == at::kBFloat16) &&
|
||||
!key.is_nested()
|
||||
? key.to(at::kFloat)
|
||||
: key;
|
||||
auto value_acc = (value.scalar_type() == at::kHalf ||
|
||||
value.scalar_type() == at::kBFloat16) &&
|
||||
!value.is_nested()
|
||||
? value.to(at::kFloat)
|
||||
: value;
|
||||
auto attn_mask = attn_mask_;
|
||||
// Naive, composite implementation defined here.
|
||||
|
||||
// Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
|
||||
bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
|
||||
const auto scaling_factor = sdp::calculate_scale(query_, is_negative_scaling ? std::abs(scale.value()) : scale).sqrt();
|
||||
// Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for
|
||||
// math
|
||||
bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
|
||||
const auto scaling_factor =
|
||||
sdp::calculate_scale(
|
||||
query_acc, is_negative_scaling ? std::abs(scale.value()) : scale)
|
||||
.sqrt();
|
||||
|
||||
const auto query = query_ * (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor: scaling_factor);
|
||||
if (is_causal) {
|
||||
TORCH_CHECK(!attn_mask.has_value(),
|
||||
"_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
|
||||
TORCH_CHECK(!query.is_nested() && !key.is_nested(),
|
||||
"_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");
|
||||
const auto query = query_acc *
|
||||
(is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor
|
||||
: scaling_factor);
|
||||
if (is_causal) {
|
||||
TORCH_CHECK(
|
||||
!attn_mask.has_value(),
|
||||
"_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
|
||||
TORCH_CHECK(
|
||||
!query.is_nested() && !key_acc.is_nested(),
|
||||
"_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");
|
||||
|
||||
// Replace attn_mask with causal mask; lower triangular elements take part in attention.
|
||||
const auto L = query.sym_size(-2), S = key.sym_size(-2);
|
||||
attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
|
||||
attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
|
||||
// Replace attn_mask with causal mask; lower triangular elements take part
|
||||
// in attention.
|
||||
const auto L = query.sym_size(-2), S = key_acc.sym_size(-2);
|
||||
attn_mask =
|
||||
at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
|
||||
attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
|
||||
}
|
||||
auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor);
|
||||
auto attn = at::matmul(query, key_acc.transpose(-2, -1) * scaling_factor);
|
||||
if (attn_mask.has_value()) {
|
||||
if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
|
||||
attn = attn.add(*attn_mask);
|
||||
|
|
@ -769,13 +797,13 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
|
|||
TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
|
||||
attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
|
||||
auto dropout_scaling = 1.0 / (1 - dropout_p);
|
||||
return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn);
|
||||
return std::make_tuple(at::matmul(attn, value_acc * dropout_scaling).to(origin_dtype), attn.to(origin_dtype));
|
||||
} else {
|
||||
attn = at::dropout(attn, dropout_p, true);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(at::matmul(attn, value), attn);
|
||||
return std::make_tuple(at::matmul(attn, value_acc).to(origin_dtype), attn.to(origin_dtype));
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
|
|
@ -998,6 +1026,4 @@ Tensor triton_multi_head_attention(
|
|||
#endif
|
||||
return proj;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -31,7 +31,6 @@ from torch.testing._internal.common_methods_invocations import (
|
|||
from torch.testing._internal.common_modules import module_db, modules
|
||||
from torch.testing._internal.common_utils import (
|
||||
is_iterable_of_tensors,
|
||||
IS_MACOS,
|
||||
run_tests,
|
||||
skipIfCrossRef,
|
||||
skipIfTorchDynamo,
|
||||
|
|
@ -1171,7 +1170,7 @@ class DecompOneOffTests(TestCase):
|
|||
[
|
||||
xfail(
|
||||
"nn.functional.scaled_dot_product_attention",
|
||||
dtypes=[torch.half] + ([torch.bfloat16] if IS_MACOS else []),
|
||||
dtypes=[torch.half],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2834,8 +2834,8 @@ class TestSDPACudaOnly(NNTestCase):
|
|||
(out_ref, out_lp_ref, out),
|
||||
*zip(grads_ref, grads_ref_lp, grads),
|
||||
fudge_factors={
|
||||
'out': 2.0 ,
|
||||
'grad_query': 18.0 ,
|
||||
'out': 3.0 ,
|
||||
'grad_query': 150.0 ,
|
||||
'grad_key': 25.0,
|
||||
'grad_value': 8.5,
|
||||
}
|
||||
|
|
@ -2931,8 +2931,8 @@ class TestSDPACudaOnly(NNTestCase):
|
|||
(out_ref, out_lp_ref, out),
|
||||
*zip(grads_ref, grads_ref_lp, grads),
|
||||
fudge_factors={
|
||||
"out": 1.75,
|
||||
"grad_query": 18.0,
|
||||
"out": 4,
|
||||
"grad_query": 150.0,
|
||||
"grad_key": 25.0,
|
||||
"grad_value": 8.0,
|
||||
"grad_attn_mask": 45.0,
|
||||
|
|
@ -3030,10 +3030,10 @@ class TestSDPACudaOnly(NNTestCase):
|
|||
(out_ref, out_lp_ref, out),
|
||||
*zip(grads_ref, grads_ref_lp, grads),
|
||||
fudge_factors={
|
||||
'out': 1.5,
|
||||
'grad_query': 13.0,
|
||||
'grad_key': 2.0,
|
||||
'grad_value': 1.5,
|
||||
'out': 2.2,
|
||||
'grad_query': 160.0,
|
||||
'grad_key': 8.0,
|
||||
'grad_value': 4,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -3179,8 +3179,8 @@ class TestSDPACudaOnly(NNTestCase):
|
|||
*zip(grads_ref, grads_ref_lp, grads),
|
||||
fudge_factors={
|
||||
'out': 2.0,
|
||||
'grad_query': 12.0,
|
||||
'grad_key': 2.0,
|
||||
'grad_query': 100.0,
|
||||
'grad_key': 8.0,
|
||||
'grad_value': 2.0,
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5686,6 +5686,7 @@ Note:
|
|||
Due to the nature of fusing floating point operations, the output of this function may be different
|
||||
depending on what backend kernel is chosen.
|
||||
The c++ implementation supports torch.float64 and can be used when higher precision is required.
|
||||
For math backend, all intermediates are kept in torch.float if inputs are in torch.half or torch.bfloat16.
|
||||
For more information please see :doc:`/notes/numerical_accuracy`
|
||||
|
||||
Note:
|
||||
|
|
|
|||
Loading…
Reference in a new issue