diff --git a/test/inductor/test_online_softmax.py b/test/inductor/test_online_softmax.py
index e038487a27b..156f696f153 100644
--- a/test/inductor/test_online_softmax.py
+++ b/test/inductor/test_online_softmax.py
@@ -1,16 +1,21 @@
 # Owner(s): ["module: inductor"]
+import math
 import os
 
 from triton.testing import do_bench
 
 import torch
 import torch._inductor.config as inductor_config
+from torch._dynamo.utils import same
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code
-from torch.testing._internal.common_utils import IS_LINUX, parametrize, instantiate_parametrized_tests
-from torch.testing._internal.inductor_utils import HAS_CUDA, GPU_TYPE
-from torch._dynamo.utils import same
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    IS_LINUX,
+    parametrize,
+)
+from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA
 
 
 DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
 
@@ -43,7 +48,7 @@ class TestOnlineSoftmax(TestCase):
     def test_log_softmax(self):
         self.do_test_acc_and_perf(torch.log_softmax)
 
-    def get_softmax_wrapper(self, V=50304, use_log_softmax=False):
+    def get_softmax_wrapper(self, V=50304, use_log_softmax=False, device=GPU_TYPE):
         N = 32 * 1024
 
         @torch.compile
@@ -53,7 +58,7 @@ class TestOnlineSoftmax(TestCase):
             else:
                 return torch.softmax(x, dim=-1)
 
-        x = torch.randn(N, V, dtype=torch.bfloat16, device=GPU_TYPE)
+        x = torch.randn(N, V, dtype=torch.bfloat16, device=device)
 
         out, source_codes = run_and_get_code(f, x)
         return source_codes[0]
@@ -70,6 +75,12 @@ class TestOnlineSoftmax(TestCase):
 
         self.assertEqual(wrapper_code.count("for r0_offset in"), 2)
 
+    def test_no_online_softmax_for_cpu(self):
+        code = self.get_softmax_wrapper(V=2048, device="cpu")
+
+        # The CPU backend needs an explicit loop across the rows;
+        # on GPU that parallelism comes from the hardware.
+        self.assertEqual(code.count("for(int64_t"), 4)
 
     def test_codegen_softmax_persistent_reduction(self):
         """
@@ -78,11 +89,34 @@ class TestOnlineSoftmax(TestCase):
         wrapper_code = self.get_softmax_wrapper(1024)
         self.assertEqual(wrapper_code.count("for r0_offset in"), 0)
 
-    # This test only work if we use pattern matcher rather the decompose
-    # softmax/log_softmax specially.
+    @inductor_config.patch("triton.persistent_reductions", False)
+    def test_sdpa(self):
+        """
+        Make sure online softmax does not conflict with the SDPA
+        patterns.
+        """
+        q, k, v = (
+            torch.randn((4, 2, 16, 32), device=GPU_TYPE, dtype=torch.bfloat16)
+            for _ in range(3)
+        )
+
+        def f(q, k, v):
+            return (
+                torch.matmul(q, k.transpose(-2, -1))
+                .div(math.sqrt(k.shape[-1]))
+                .softmax(dim=-1)
+                .matmul(v)
+            )
+
+        opt_f = torch.compile(f)
+        ref = f(q, k, v)
+        act, (code,) = run_and_get_code(opt_f, q, k, v)
+        self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2))
+        self.assertTrue("aten._scaled_dot_product_" in code)
+
     @parametrize("nrow", [2, 2048])
     @parametrize("dim", [-1, 0, 1])
-    def no_test_prepare_softmax(self, dim, nrow):
+    def test_prepare_softmax(self, dim, nrow):
         def f(x, dim):
             xmax = x.amax(dim=dim, keepdim=True)
             xsum = (x - xmax).exp().sum(dim=dim, keepdim=True)
@@ -101,6 +135,7 @@ class TestOnlineSoftmax(TestCase):
             expected_num_loop = 1
         self.assertEqual(code.count("for r0_offset in"), expected_num_loop)
 
+
 instantiate_parametrized_tests(TestOnlineSoftmax)
 
 if __name__ == "__main__":
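For reference, the "online softmax" exercised by these tests computes a row's max and its sum of exponentials in a single pass, rescaling the running sum whenever the running max grows, instead of the usual two-pass max-then-sum schedule; that is what the `for r0_offset in` loop counts above are checking. A minimal pure-PyTorch sketch of the recurrence (illustration only; the blocked loop and all names here are ours, not the generated Triton kernel):

    import torch

    def online_softmax_reference(x: torch.Tensor, block: int = 16) -> torch.Tensor:
        # One pass over the last dim in blocks: keep a running max and a running
        # sum of exp(x - running_max). When the max grows, rescale the old sum
        # by exp(old_max - new_max) instead of re-reading earlier elements.
        running_max = torch.full(x.shape[:-1], float("-inf"), dtype=x.dtype)
        running_sum = torch.zeros(x.shape[:-1], dtype=x.dtype)
        for start in range(0, x.shape[-1], block):
            chunk = x[..., start : start + block]
            new_max = torch.maximum(running_max, chunk.amax(dim=-1))
            rescale = (running_max - new_max).exp()  # 0.0 on the first block
            running_sum = running_sum * rescale + (
                chunk - new_max.unsqueeze(-1)
            ).exp().sum(dim=-1)
            running_max = new_max
        return (x - running_max.unsqueeze(-1)).exp() / running_sum.unsqueeze(-1)

    x = torch.randn(32, 2048)
    torch.testing.assert_close(online_softmax_reference(x), x.softmax(dim=-1))
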
+ """ + q, k, v = ( + torch.randn((4, 2, 16, 32), device=GPU_TYPE, dtype=torch.bfloat16) + for _ in range(3) + ) + + def f(q, k, v): + return ( + torch.matmul(q, k.transpose(-2, -1)) + .div(math.sqrt(k.shape[-1])) + .softmax(dim=-1) + .matmul(v) + ) + + opt_f = torch.compile(f) + ref = f(q, k, v) + act, (code,) = run_and_get_code(opt_f, q, k, v) + self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2)) + self.assertTrue("aten._scaled_dot_product_" in code) + @parametrize("nrow", [2, 2048]) @parametrize("dim", [-1, 0, 1]) - def no_test_prepare_softmax(self, dim, nrow): + def test_prepare_softmax(self, dim, nrow): def f(x, dim): xmax = x.amax(dim=dim, keepdim=True) xsum = (x - xmax).exp().sum(dim=dim, keepdim=True) @@ -101,6 +135,7 @@ class TestOnlineSoftmax(TestCase): expected_num_loop = 1 self.assertEqual(code.count("for r0_offset in"), expected_num_loop) + instantiate_parametrized_tests(TestOnlineSoftmax) if __name__ == "__main__": diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index a6c37cb54f6..19ceafc5e76 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -51,9 +51,6 @@ quantized = torch.ops.quantized _quantized = torch.ops._quantized quantized_decomposed = torch.ops.quantized_decomposed -import os -USE_DECOMP_FOR_ONLINE_SOFTMAX = os.getenv("USE_DECOMP_FOR_ONLINE_SOFTMAX", "1") == "1" - inductor_decompositions = get_decompositions( [ aten._adaptive_avg_pool2d_backward, @@ -71,6 +68,7 @@ inductor_decompositions = get_decompositions( aten.lcm, aten.leaky_relu, aten.linalg_vector_norm, + aten._log_softmax, aten.max_pool2d_with_indices_backward, aten._native_batch_norm_legit, aten._native_batch_norm_legit_functional, @@ -85,6 +83,7 @@ inductor_decompositions = get_decompositions( aten.nll_loss2d_backward, aten.permute_copy, aten.rrelu_with_noise_backward, + aten._softmax, aten.sin_, aten.sqrt_, out_dtype, @@ -95,7 +94,7 @@ inductor_decompositions = get_decompositions( aten.upsample_bilinear2d.vec, quantized.linear_dynamic_fp16_unpacked_weight, _quantized.wrapped_quantized_linear, - ] + ([] if USE_DECOMP_FOR_ONLINE_SOFTMAX else [aten._log_softmax, aten._softmax]) + ] ) decompositions = {**core_aten_decompositions(), **inductor_decompositions} @@ -120,12 +119,6 @@ decomps_to_exclude = [ aten.baddbmm, # upcasts to fp32, perf issue ] -if USE_DECOMP_FOR_ONLINE_SOFTMAX: - decomps_to_exclude += [ - aten._softmax, # inductor will override this rule - aten._log_softmax, # inductor will override this rule - ] - remove_decompositions(decompositions, decomps_to_exclude) @@ -1076,77 +1069,3 @@ def rrelu_with_noise_functional( else: negative_slope = (lower + upper) / 2 return aten.leaky_relu(self, negative_slope), torch.Tensor() - - -def _use_online_softmax(x: torch.Tensor, dim: int) -> bool: - if not config.online_softmax: - return False - - # Don't do online softmax for scalar or 1d tensor - if x.dim() < 2: - return False - - # Only do online softmax for GPU for now - if x.device.type != "cuda": - return False - - return True - - -if USE_DECOMP_FOR_ONLINE_SOFTMAX: - @register_decomposition(aten._softmax) - def _softmax(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor: - # eager softmax returns a contiguous tensor. Ensure that decomp also returns - # a contiguous tensor. 
diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py
index 1dc2f78657b..f0dfb870410 100644
--- a/torch/_inductor/fx_passes/joint_graph.py
+++ b/torch/_inductor/fx_passes/joint_graph.py
@@ -21,6 +21,7 @@ from torch.utils._ordered_set import OrderedSet
 from .. import config
 from ..pattern_matcher import (
     CallFunction,
+    fwd_only,
     init_once_fakemode,
     KeywordArg,
     Match,
@@ -28,7 +29,6 @@ from ..pattern_matcher import (
     PatternMatcherPass,
     register_graph_pattern,
     register_replacement,
-    fwd_only,
     stable_topological_sort,
 )
 from .replace_random import replace_random_passes
@@ -44,27 +44,38 @@ pass_patterns = [
     PatternMatcherPass(),
 ]
 
-import os
-USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX = os.getenv("USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX") == "1"
-if USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX:
-    def prepare_softmax_pattern(x, dim):
-        xmax = x.amax(dim=dim, keepdim=True)
-        xsub = x - xmax
-        xexp = xsub.exp()
-        xsum = xexp.sum(dim=dim, keepdim=True)
-        return xmax, xsum, xsub, xexp
-
-    def prepare_softmax_replacement(x, dim):
-        """
-        Return xsub since otherwise log-softmax can not be matched
-        due to a use of this intermediate node. Same reason to return
-        xsub.exp() for softmax.
-        """
-        from torch._inductor.inductor_prims import prepare_softmax_online
-        xmax, xsum = prepare_softmax_online(x, dim)
-        xsub = x - xmax
-        return xmax, xsum, xsub, xsub.exp()
+
+def prepare_softmax_pattern(x, dim):
+    xmax = x.amax(dim=dim, keepdim=True)
+    xsub = x - xmax
+    xexp = xsub.exp()
+    xsum = xexp.sum(dim=dim, keepdim=True)
+    return xmax, xsum, xsub, xexp
+
+
+def prepare_softmax_replacement(x, dim):
+    """
+    Also return xsub; otherwise log-softmax cannot be matched,
+    since this intermediate node has an extra use. Return
+    xsub.exp() for softmax for the same reason.
+    """
+    from torch._inductor.inductor_prims import prepare_softmax_online
+
+    xmax, xsum = prepare_softmax_online(x, dim)
+    xsub = x - xmax
+    return xmax, xsum, xsub, xsub.exp()
+
+
+def prepare_softmax_extra_check(match):
+    """
+    We currently only have Triton online-softmax kernels.
+    """
+    return (
+        config.online_softmax
+        and match.kwargs["x"].meta["val"].device.type == "cuda"
+        and config.cuda_backend == "triton"
+    )
 
 
 @init_once_fakemode
 def lazy_init():
@@ -76,16 +87,15 @@ def lazy_init():
     _sfdp_init()
     _misc_patterns_init()
 
-    if USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX:
-        register_replacement(
-            prepare_softmax_pattern,
-            prepare_softmax_replacement,
-            [torch.empty(4, 8)],
-            scalar_workaround=dict(dim=-1),
-            trace_fn=fwd_only,
-            pass_dicts=pass_patterns[1],
-            extra_check=lambda *args, **kwargs: config.online_softmax
-        )
+    register_replacement(
+        prepare_softmax_pattern,
+        prepare_softmax_replacement,
+        [torch.empty(4, 8)],
+        scalar_workaround=dict(dim=-1),
+        trace_fn=fwd_only,
+        pass_dicts=pass_patterns[1],
+        extra_check=prepare_softmax_extra_check,
+    )
 
 
 def remove_no_ops(
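Usage sketch for the replacement above (assumes a CUDA input, the Triton backend, and config.online_softmax enabled, per prepare_softmax_extra_check): any traced region matching prepare_softmax_pattern, i.e. amax, subtract, exp, and sum along one dim, is rewritten to a single prepare_softmax_online node. A function shaped like the one in test_prepare_softmax should hit the pattern once compiled:

    import torch

    def f(x):
        xmax = x.amax(dim=-1, keepdim=True)  # matched as x.amax(dim=dim, keepdim=True)
        xexp = (x - xmax).exp()              # matched as xsub.exp()
        return xexp / xexp.sum(dim=-1, keepdim=True)

    opt_f = torch.compile(f)
    y = opt_f(torch.randn(128, 4096, device="cuda"))  # requires a CUDA device
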
- """ - from torch._inductor.inductor_prims import prepare_softmax_online - xmax, xsum = prepare_softmax_online(x, dim) - xsub = x - xmax - return xmax, xsum, xsub, xsub.exp() +def prepare_softmax_pattern(x, dim): + xmax = x.amax(dim=dim, keepdim=True) + xsub = x - xmax + xexp = xsub.exp() + xsum = xexp.sum(dim=dim, keepdim=True) + return xmax, xsum, xsub, xexp + + +def prepare_softmax_replacement(x, dim): + """ + Return xsub since otherwise log-softmax can not be matched + due to a use of this intermediate node. Same reason to return + xsub.exp() for softmax. + """ + from torch._inductor.inductor_prims import prepare_softmax_online + + xmax, xsum = prepare_softmax_online(x, dim) + xsub = x - xmax + return xmax, xsum, xsub, xsub.exp() + + +def prepare_softmax_extra_check(match): + """ + We only have triton online softmax kernels currently. + """ + return ( + config.online_softmax + and match.kwargs["x"].meta["val"].device.type == "cuda" + and config.cuda_backend == "triton" + ) + @init_once_fakemode def lazy_init(): @@ -76,16 +87,15 @@ def lazy_init(): _sfdp_init() _misc_patterns_init() - if USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX: - register_replacement( - prepare_softmax_pattern, - prepare_softmax_replacement, - [torch.empty(4, 8)], - scalar_workaround=dict(dim=-1), - trace_fn=fwd_only, - pass_dicts=pass_patterns[1], - extra_check=lambda *args, **kwargs: config.online_softmax - ) + register_replacement( + prepare_softmax_pattern, + prepare_softmax_replacement, + [torch.empty(4, 8)], + scalar_workaround=dict(dim=-1), + trace_fn=fwd_only, + pass_dicts=pass_patterns[1], + extra_check=prepare_softmax_extra_check, + ) def remove_no_ops( diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index dabb67ca577..8d1302a1b36 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -17,8 +17,12 @@ from torch.utils._sympy.symbol import SymT from . import config, dependencies from .codegen.common import index_prevent_reordering -from .ops_handler import WrapperHandler -from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs, reduction_num_outputs +from .utils import ( + cache_on_self, + reduction_num_outputs, + sympy_index_symbol_with_prefix, + sympy_subs, +) from .virtualized import ops, V @@ -475,10 +479,9 @@ class LoopBodyBlock: ) return self._inner.store_reduction(name, index, value) - def reduction(self, dtype, src_dtype, reduction_type, value): result = self._inner.reduction(dtype, src_dtype, reduction_type, value) - + num_outputs = reduction_num_outputs(reduction_type) if num_outputs > 1: return tuple(result[i] for i in range(num_outputs))