diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index ce51a37e66b..a5f01419368 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -710,6 +710,33 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention(
       query_, key, value, attn_mask_, dropout_p, is_causal);
 }
 
+inline void validate_sdpa_input(
+    const Tensor& query_,
+    const Tensor& key,
+    const Tensor& value,
+    const c10::optional<Tensor>& attn_mask_,
+    double dropout_p,
+    bool is_causal) {
+  TORCH_CHECK(
+      query_.dtype() == key.dtype() && query_.dtype() == value.dtype(),
+      "Expected query, key, and value to have the same dtype, but got query.dtype: ",
+      query_.dtype(), " key.dtype: ", key.dtype(), " and value.dtype: ", value.dtype(), " instead.");
+  TORCH_CHECK(
+      query_.device() == key.device() && query_.device() == value.device(),
+      "Expected query, key, and value to have the same device type, but got query.device: ",
+      query_.device(), " key.device: ", key.device(), " and value.device: ", value.device(), " instead.");
+  TORCH_CHECK(
+      query_.dim() >= 2 && key.dim() >= 2 && value.dim() >= 2,
+      "Expected query, key, and value to all be at least 2 dimensional, but got query.dim: ",
+      query_.dim(), " key.dim: ", key.dim(), " and value.dim: ", value.dim(), " instead.");
+  if (attn_mask_.has_value()) {
+    auto mask_dtype = attn_mask_->dtype();
+    TORCH_CHECK(mask_dtype == at::kBool || mask_dtype == query_.dtype(),
+        "Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: ",
+        mask_dtype, " and query.dtype: ", query_.dtype(), " instead.");
+  }
+  return;
+}
 // Computes scaled dot product attention on query, key and value tensors, using
 // an optional attention mask if passed, and applying dropout if a probability
 // greater than 0.0 is specified.
@@ -745,6 +772,7 @@ Tensor scaled_dot_product_attention(
     const c10::optional<Tensor>& attn_mask_,
     double dropout_p,
     bool is_causal) {
+  validate_sdpa_input(query_, key, value, attn_mask_, dropout_p, is_causal);
   int64_t choice_int = static_cast<int64_t>(sdp::SDPBackend::math);
   if (query_.device().type() == DeviceType::CUDA){
     choice_int = _fused_sdp_choice_stub(query_.device().type(),
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h
index d0f03ebca91..14ea9875c79 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h
@@ -214,22 +214,51 @@ inline bool check_tensor_shapes(sdp_params params, bool debug) {
   return true;
 }
 
+inline bool check_equal_batch_size_and_num_heads(sdp_params params, bool debug) {
+  // This is expected to be called after check_tensor_shapes ensuring that the size()
+  // calls won't error since the inputs are all 4 dimensional
+  bool same_batch_size = params.query.size(0) == params.key.size(0) &&
+      params.query.size(0) == params.value.size(0);
+  // We pass through for NestedTensors since this is checked in a later filter
+  bool same_num_heads = params.query.is_nested()
+      ? true
+      : params.query.size(1) == params.key.size(1) &&
+          params.query.size(1) == params.value.size(1);
+
+  if (!(same_batch_size && same_num_heads)) {
+    if (debug) {
+      TORCH_WARN(
+          "Both fused kernels require query, key and value to have the same batch_size and num_heads. Query.sizes(): ",
+          params.query.sizes(),
+          ", Key sizes(): ",
+          params.key.sizes(),
+          ", Value sizes(): ",
+          params.value.sizes(),
+          " instead.");
+    }
+    return false;
+  }
+  return true;
+}
+
 inline bool check_head_dim_size(sdp_params params, bool debug) {
   const int64_t query_size_last = params.query.size(-1);
+  const int64_t key_size_last = params.key.size(-1);
   const int64_t value_size_last = params.value.size(-1);
-  if (!(query_size_last == params.key.size(-1) && query_size_last % 8 == 0 &&
+  if (!(query_size_last == key_size_last &&
+        query_size_last == value_size_last && query_size_last % 8 == 0 &&
         query_size_last <= 128 && value_size_last % 8 == 0 &&
         value_size_last <= 128)) {
     if (debug) {
       TORCH_WARN(
-          "Flash attention requires last dimension of inputs to be a multiple of 8 and less than or equal to 128.",
-          "Got Query.size(-1): ",
-          query_size_last,
-          ", Key.size(-1): ",
-          params.key.size(-1),
-          ", Value.size(-1): ",
-          params.value.size(-1),
-          " instead.");
+          "Flash attention requires q, k, and v to have the same last dimension, which must be a multiple of 8 and less than or equal to 128.",
+          " Got Query.size(-1): ",
+          query_size_last,
+          ", Key.size(-1): ",
+          params.key.size(-1),
+          ", Value.size(-1): ",
+          params.value.size(-1),
+          " instead.");
     }
     return false;
   }
@@ -393,9 +422,10 @@ inline bool use_flash_attention(sdp_params params, bool debug) {
   return false;
 #endif
   // Define gate functions that determine if a flash kernel can be ran
-  constexpr std::array constraints {{
+  constexpr std::array constraints {{
       check_runtime_disabled_flash,
       check_tensor_shapes,
+      check_equal_batch_size_and_num_heads,
      check_for_attn_mask,
      check_head_dim_size,
      check_gpu_sm75_or_greater,
@@ -427,11 +457,12 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
       at::kHalf, at::kFloat, at::kBFloat16};
 
   // Define gate functions that determine if a flash kernel can be ran
-  constexpr std::array constraints{{
+  constexpr std::array constraints{{
       check_gpu_sm50_or_greater,
       check_runtime_disabled_mem_efficient,
       check_requires_grad_and_nested,
       check_tensor_shapes,
+      check_equal_batch_size_and_num_heads,
      check_for_attn_mask,
      check_head_dim_size_mem_efficient,
      check_gpu_sm86_head_dim_128,
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 740faf4c460..3a85be95cac 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -1076,6 +1076,13 @@ class TestSDPA(NNTestCase):
     _do_cuda_memory_leak_check = True
     _do_cuda_non_default_stream = True
 
+    backend_map = {
+        SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
+        SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
+        SDPBackend.EFFICIENT_ATTENTION: {
+            "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
+    }
+
     def rand_tensor(self, shape: Tuple[int], device: str, dtype: torch.dtype,
                     type: str, requires_grad: bool = False, packed: bool = False) -> torch.Tensor:
         """Creates rand dense or nested tensor with given shape and type.
@@ -1480,22 +1487,22 @@ class TestSDPA(NNTestCase):
             assert torch._fused_sdp_choice(query, key, value) == (
                 SDPBackend.EFFICIENT_ATTENTION if warn_only else SDPBackend.MATH)
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "CUDA unavailable")
-    def test_sdp_runtime_dispatch(self):
-        # We will test all the constraints that we know will cause a failure
-        # The problem is that any code path that goes down flash_attention
-        # will fail on CI/CD becuase it is not compiled with the right flags
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86Device, "Does not support fused SDPA or not SM86 hardware")
+    def test_memory_efficient_sm86_failure(self):
         device = 'cuda'
         dtype = torch.float16
         make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=dtype)
-        if isSM86Device:
-            # See check_gpu_sm86_head_dim_128 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
-            size = (2, 2, 4, 128)
-            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
-            with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, None, 0.0, False))
+        # See check_gpu_sm86_head_dim_128 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
+        size = (2, 2, 4, 128)
+        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+        with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, None, 0.0, False))
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
+    def test_dispatch_fails_no_backend(self):
+        dtype = torch.float16
+        device = "cuda"
         with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False):
             size = (2, 3, 4)
             q = torch.randn(size, device=device, dtype=dtype)
@@ -1506,42 +1513,92 @@ class TestSDPA(NNTestCase):
             self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
                                    lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v))
 
-        if SM80OrLater:
-            with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
-                # Failures for invalid input
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
+    @parametrize(
+        "kernel",
+        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+        if SM80OrLater
+        else [SDPBackend.EFFICIENT_ATTENTION],
+    )
+    def test_invalid_fused_inputs_dim_3(self, kernel: SDPBackend):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # Dim is not 4
+            device = "cuda"
+            size = (2, 3, 8)
+            dtype = torch.float16
+            q = torch.randn(size, device=device, dtype=dtype)
+            k = torch.randn(size, device=device, dtype=dtype)
+            v = torch.randn(size, device=device, dtype=dtype)
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, None, 0.0, False))
 
-                # Dim is not 4
-                q = torch.randn(size, device=device, dtype=dtype)
-                k = torch.randn(size, device=device, dtype=dtype)
-                v = torch.randn(size, device=device, dtype=dtype)
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, None, 0.0, False))
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
+    @parametrize(
+        "kernel",
+        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+        if SM80OrLater
+        else [SDPBackend.EFFICIENT_ATTENTION],
+    )
+    def test_invalid_fused_inputs_broadcast(self, kernel: SDPBackend):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # Fused Kernels don't support broadcasting
+            device = "cuda"
+            dtype = torch.float16
+            size = (2, 4, 3, 8)
+            size_broadcast = (1, 4, 3, 8)
+            q = torch.randn(size_broadcast, device=device, dtype=dtype)
+            k = torch.randn(size, device=device, dtype=dtype)
+            v = torch.randn(size, device=device, dtype=dtype)
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, None, 0.0, False))
 
-                # The embed dim per head is not divisible by 8 for flash attention
-                size = (2, 2, 3, 4)
-                q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, None, 0.0, False))
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support fused scaled dot product attention")
+    @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    def test_invalid_fused_inputs_head_dim(self, kernel: SDPBackend):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # The embed dim per head is not divisible by 8 for flash attention
+            device = "cuda"
+            dtype = torch.float16
+            make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=dtype)
+            size = (2, 2, 3, 9)
+            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, None, 0.0, False))
 
-                # Invalid dtype for both Flash Attention and Mem Efficient Attention
-                size = (2, 2, 3, 16)
-                make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float64)
-                q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, None, 0.0, False))
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
+    @parametrize(
+        "kernel",
+        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+        if SM80OrLater
+        else [SDPBackend.EFFICIENT_ATTENTION],
+    )
+    def test_invalid_fused_inputs_invalid_dtype(self, kernel: SDPBackend):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # Invalid dtype for both Flash Attention and Mem Efficient Attention
+            device = "cuda"
+            size = (2, 2, 3, 16)
+            make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float64)
+            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, None, 0.0, False))
 
-                # Invalid dtype for Flash Attention
-                make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float32)
-                q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, None, 0.0, False))
-
-                # Failures for unsupported SDP args
-                q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
-
-                # Non-None attention mask
-                self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                    q, k, v, torch.ones_like(q), 0.0, False))
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
+    @parametrize(
+        "kernel",
+        [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
+        if SM80OrLater
+        else [SDPBackend.EFFICIENT_ATTENTION],
+    )
+    def test_invalid_fused_inputs_attn_mask_present(self, kernel: SDPBackend):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # Failures for unsupported SDP args
+            device = "cuda"
+            size = (2, 2, 3, 16)
+            make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=torch.float16)
+            q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+            # Non-None attention mask
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, torch.ones_like(q), 0.0, False))
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "CUDA unavailable")
     def test_unaligned_tensors(self):
@@ -1784,6 +1841,39 @@ class TestSDPA(NNTestCase):
             self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
                              atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
 
+    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    @parametrize("device", ["cpu", "cuda"] if TEST_CUDA else ["cpu"])
+    def test_invalid_inputs_different_datatypes(self, kernel: SDPBackend, device: str):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # Different datatypes
+            shape = (1, 4, 8, 16)
+            query = torch.randn(shape, dtype=torch.float32, device=device)
+            key = torch.randn(shape, dtype=torch.float16, device=device)
+            value = torch.randn(shape, dtype=torch.float16, device=device)
+            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
+
+    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    @parametrize("device", ["cpu", "cuda"] if TEST_CUDA else ["cpu"])
+    def test_invalid_inputs_different_devices(self, kernel: SDPBackend, device: str):
+        # Different devices
+        shape = (1, 4, 8, 16)
+        if device == "cuda":
+            query = torch.randn(shape, dtype=torch.float32, device=device)
+            key = torch.randn(shape, dtype=torch.float16, device='cpu')
+            value = torch.randn(shape, dtype=torch.float16, device='cpu')
+            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
+
+    @parametrize("kernel", [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
+    @parametrize("device", ["cpu", "cuda"] if TEST_CUDA else ["cpu"])
+    def test_invalid_inputs_1_dimensional_inputs(self, kernel: SDPBackend, device: str):
+        with sdp_kernel(**self.backend_map[kernel]):
+            # 1 dimensional input
+            shape = (1, 4)
+            query = torch.randn(4, dtype=torch.float16, device=device)
+            key = torch.randn(shape, dtype=torch.float16, device=device)
+            value = torch.randn(shape, dtype=torch.float16, device=device)
+            self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
+
 # TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
 # cross device / dtype testing.
 instantiate_parametrized_tests(TestTransformers)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 722073cb89c..3e92a332c9e 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7670,17 +7670,32 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
     dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
     dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
 
-    qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape)]
+    broadcast_tuple = ((num_heads, seq_q, head_dim), (batch, num_heads, seq_kv, head_dim))
+    qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
+
+    samples = []
     for qkv_shapes, is_causal, dropout_p in product(
             qkv_shapes, [True, False], [0.0, 0.5]):
         shape_q, shape_kv = qkv_shapes
-        yield SampleInput(
+        samples.append(SampleInput(
             make(shape_q),
             make(shape_kv),
             make(shape_kv),
             is_causal=is_causal,
             dropout_p=dropout_p
-        )
+        ))
+
+    # Add non standard shapes
+    diff_v_head_dim = SampleInput(
+        make((batch, num_heads, seq_q, head_dim)),
+        make((batch, num_heads, seq_kv, head_dim)),
+        make((batch, num_heads, seq_kv, head_dim + 8)),
+        is_causal=is_causal,
+        dropout_p=dropout_p
+    )
+    samples.append(diff_v_head_dim)
+
+    yield from samples
 
 def sample_inputs_pairwise_distance(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
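
Reviewer note (not part of the patch): a minimal sketch of the behavior the new validate_sdpa_input check introduces, assuming a build that contains this change. Mismatched query/key/value dtypes now fail up front with a RuntimeError on every backend, including the math fallback, instead of failing later inside a kernel. The shapes and dtypes below are illustrative only.

    # Hypothetical usage sketch for the added input validation.
    import torch
    import torch.nn.functional as F

    query = torch.randn(1, 4, 8, 16, dtype=torch.float32)
    key = torch.randn(1, 4, 8, 16, dtype=torch.float16)    # dtype differs from query
    value = torch.randn(1, 4, 8, 16, dtype=torch.float16)

    try:
        F.scaled_dot_product_attention(query, key, value)
    except RuntimeError as err:
        # Expected message: "Expected query, key, and value to have the same dtype, ..."
        print(err)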