Update on "[inductor] online softmax"

Softmax needs to do some preparation work that accesses the input tensor in two passes:
- compute the amax of each row
- compute `(x - amax).exp().sum()` for each row
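
In eager PyTorch terms, a minimal sketch of this two-pass preparation (the function name is illustrative, not from the PR):

```python
import torch

def prepare_softmax_two_pass(x: torch.Tensor, dim: int = -1):
    # pass 1: read x to find the row-wise max
    amax = x.amax(dim=dim, keepdim=True)
    # pass 2: read x again to accumulate the shifted exponentials
    xsum = (x - amax).exp().sum(dim=dim, keepdim=True)
    return amax, xsum
```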

When the row size is large, the cache cannot hold all the active data, and accessing the input in multiple passes increases execution time since the kernel is memory-bandwidth bound.

Online softmax uses a customized reduction to compute the max and the sum at the same time, accessing the data in a single pass. Check this paper for more details (https://arxiv.org/abs/1805.02867).
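
As a rough sketch of the idea (not the generated Triton code), the online reduction from the paper keeps a running max `m` and a running sum `d` of shifted exponentials, rescaling `d` whenever the max grows:

```python
import torch

def online_max_sum(row: torch.Tensor):
    # Single pass over the row: the max and sum reductions are fused.
    m = torch.tensor(float("-inf"), dtype=row.dtype)
    d = torch.zeros((), dtype=row.dtype)
    for x in row:
        new_m = torch.maximum(m, x)
        # Rescale the old partial sum to the new max before adding the new term.
        d = d * torch.exp(m - new_m) + torch.exp(x - new_m)
        m = new_m
    return m, d
```

The generated kernel (linked below) applies an equivalent combine over blocks of the row rather than one element at a time, but the math is the same.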

Also, here is an online softmax kernel generated by inductor as a reference: https://gist.github.com/shunting314/67ae4fffd45d4f2753c781780332fa54

## Microbenchmark

- `TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_ONLINE_SOFTMAX=0 DO_PERF_TEST=1 python test/inductor/test_online_softmax.py -k test_softmax` : without online softmax
  - eager_ms=6.671296119689941
  - opt_ms=8.06931209564209
- `TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_ONLINE_SOFTMAX=1 DO_PERF_TEST=1 python test/inductor/test_online_softmax.py -k test_softmax`: with online softmax
  - eager_ms=6.634047985076904
  - opt_ms=6.230591773986816

Ideally, online softmax should save about 2ms here. We save about 1.84ms in practice.
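
As a rough sanity check on the 2ms figure (assuming the perf test uses the same 32768 x 50304 bf16 input as `get_softmax_wrapper` and a GPU with roughly 1.6 TB/s of usable memory bandwidth, neither of which is measured here): one pass over that input reads about 3.3 GB, so eliminating one read pass should be worth roughly 3.3 GB / 1.6 TB/s ≈ 2 ms.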

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov peterbell10 ngimel 

[ghstack-poisoned]
Shunting Zhang 2025-02-07 14:00:17 -08:00
commit 426e1afd11
4 changed files with 94 additions and 127 deletions


@@ -1,16 +1,21 @@
# Owner(s): ["module: inductor"]
import math
import os
from triton.testing import do_bench
import torch
import torch._inductor.config as inductor_config
from torch._dynamo.utils import same
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_utils import IS_LINUX, parametrize, instantiate_parametrized_tests
from torch.testing._internal.inductor_utils import HAS_CUDA, GPU_TYPE
from torch._dynamo.utils import same
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_LINUX,
parametrize,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA
DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
@@ -43,7 +48,7 @@ class TestOnlineSoftmax(TestCase):
def test_log_softmax(self):
self.do_test_acc_and_perf(torch.log_softmax)
def get_softmax_wrapper(self, V=50304, use_log_softmax=False):
def get_softmax_wrapper(self, V=50304, use_log_softmax=False, device=GPU_TYPE):
N = 32 * 1024
@torch.compile
@@ -53,7 +58,7 @@ class TestOnlineSoftmax(TestCase):
else:
return torch.softmax(x, dim=-1)
x = torch.randn(N, V, dtype=torch.bfloat16, device=GPU_TYPE)
x = torch.randn(N, V, dtype=torch.bfloat16, device=device)
out, source_codes = run_and_get_code(f, x)
return source_codes[0]
@@ -70,6 +75,12 @@ class TestOnlineSoftmax(TestCase):
self.assertEqual(wrapper_code.count("for r0_offset in"), 2)
def test_no_online_softmax_for_cpu(self):
code = self.get_softmax_wrapper(V=2048, device="cpu")
# CPU needs an explicit loop across different rows.
# For GPU, this is parallelized by the hardware.
self.assertEqual(code.count("for(int64_t"), 4)
def test_codegen_softmax_persistent_reduction(self):
"""
@@ -78,11 +89,34 @@ class TestOnlineSoftmax(TestCase):
wrapper_code = self.get_softmax_wrapper(1024)
self.assertEqual(wrapper_code.count("for r0_offset in"), 0)
# This test only works if we use the pattern matcher rather than decomposing
# softmax/log_softmax specially.
@inductor_config.patch("triton.persistent_reductions", False)
def test_sdpa(self):
"""
Make sure online softmax here does not conflict with the sdpa
patterns.
"""
q, k, v = (
torch.randn((4, 2, 16, 32), device=GPU_TYPE, dtype=torch.bfloat16)
for _ in range(3)
)
def f(q, k, v):
return (
torch.matmul(q, k.transpose(-2, -1))
.div(math.sqrt(k.shape[-1]))
.softmax(dim=-1)
.matmul(v)
)
opt_f = torch.compile(f)
ref = f(q, k, v)
act, (code,) = run_and_get_code(opt_f, q, k, v)
self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2))
self.assertTrue("aten._scaled_dot_product_" in code)
@parametrize("nrow", [2, 2048])
@parametrize("dim", [-1, 0, 1])
def no_test_prepare_softmax(self, dim, nrow):
def test_prepare_softmax(self, dim, nrow):
def f(x, dim):
xmax = x.amax(dim=dim, keepdim=True)
xsum = (x - xmax).exp().sum(dim=dim, keepdim=True)
@@ -101,6 +135,7 @@ class TestOnlineSoftmax(TestCase):
expected_num_loop = 1
self.assertEqual(code.count("for r0_offset in"), expected_num_loop)
instantiate_parametrized_tests(TestOnlineSoftmax)
if __name__ == "__main__":


@@ -51,9 +51,6 @@ quantized = torch.ops.quantized
_quantized = torch.ops._quantized
quantized_decomposed = torch.ops.quantized_decomposed
import os
USE_DECOMP_FOR_ONLINE_SOFTMAX = os.getenv("USE_DECOMP_FOR_ONLINE_SOFTMAX", "1") == "1"
inductor_decompositions = get_decompositions(
[
aten._adaptive_avg_pool2d_backward,
@@ -71,6 +68,7 @@ inductor_decompositions = get_decompositions(
aten.lcm,
aten.leaky_relu,
aten.linalg_vector_norm,
aten._log_softmax,
aten.max_pool2d_with_indices_backward,
aten._native_batch_norm_legit,
aten._native_batch_norm_legit_functional,
@@ -85,6 +83,7 @@ inductor_decompositions = get_decompositions(
aten.nll_loss2d_backward,
aten.permute_copy,
aten.rrelu_with_noise_backward,
aten._softmax,
aten.sin_,
aten.sqrt_,
out_dtype,
@@ -95,7 +94,7 @@ inductor_decompositions = get_decompositions(
aten.upsample_bilinear2d.vec,
quantized.linear_dynamic_fp16_unpacked_weight,
_quantized.wrapped_quantized_linear,
] + ([] if USE_DECOMP_FOR_ONLINE_SOFTMAX else [aten._log_softmax, aten._softmax])
]
)
decompositions = {**core_aten_decompositions(), **inductor_decompositions}
@@ -120,12 +119,6 @@ decomps_to_exclude = [
aten.baddbmm, # upcasts to fp32, perf issue
]
if USE_DECOMP_FOR_ONLINE_SOFTMAX:
decomps_to_exclude += [
aten._softmax, # inductor will override this rule
aten._log_softmax, # inductor will override this rule
]
remove_decompositions(decompositions, decomps_to_exclude)
@@ -1076,77 +1069,3 @@ def rrelu_with_noise_functional(
else:
negative_slope = (lower + upper) / 2
return aten.leaky_relu(self, negative_slope), torch.Tensor()
def _use_online_softmax(x: torch.Tensor, dim: int) -> bool:
if not config.online_softmax:
return False
# Don't do online softmax for scalar or 1d tensor
if x.dim() < 2:
return False
# Only do online softmax for GPU for now
if x.device.type != "cuda":
return False
return True
if USE_DECOMP_FOR_ONLINE_SOFTMAX:
@register_decomposition(aten._softmax)
def _softmax(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor:
# eager softmax returns a contiguous tensor. Ensure that decomp also returns
# a contiguous tensor.
x = x.contiguous()
if half_to_float:
assert x.dtype == torch.half
computation_dtype, result_dtype = utils.elementwise_dtypes(
x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
x = x.to(computation_dtype)
if x.numel() == 0:
unnormalized = torch.exp(x)
result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
elif not _use_online_softmax(x, dim):
# don't want to affect small softmax. That may interfere with
# the attention patterns.
x_max = torch.amax(x, dim, keepdim=True)
unnormalized = torch.exp(x - x_max)
result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
else:
x_max, t_sum = inductor_prims.online_softmax(x, dim)
result = torch.exp(x - x_max) / t_sum
if not half_to_float:
result = result.to(result_dtype)
return result
@register_decomposition(aten._log_softmax)
def _log_softmax(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor:
# eager log_softmax returns a contiguous tensor. Ensure that decomp also
# returns a contiguous tensor.
x = x.contiguous()
if half_to_float:
assert x.dtype == torch.half
computation_dtype, result_dtype = utils.elementwise_dtypes(
x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
x = x.to(computation_dtype)
if x.numel() == 0:
shifted = x
shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
elif not _use_online_softmax(x, dim):
x_max = torch.amax(x, dim, keepdim=True)
shifted = x - x_max
shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
else:
x_max, t_sum = inductor_prims.online_softmax(x, dim)
shifted = x - x_max
shifted_logsumexp = torch.log(t_sum)
result = shifted - shifted_logsumexp
if not half_to_float:
result = result.to(result_dtype)
return result


@@ -21,6 +21,7 @@ from torch.utils._ordered_set import OrderedSet
from .. import config
from ..pattern_matcher import (
CallFunction,
fwd_only,
init_once_fakemode,
KeywordArg,
Match,
@@ -28,7 +29,6 @@ from ..pattern_matcher import (
PatternMatcherPass,
register_graph_pattern,
register_replacement,
fwd_only,
stable_topological_sort,
)
from .replace_random import replace_random_passes
@@ -44,27 +44,38 @@ pass_patterns = [
PatternMatcherPass(),
]
import os
USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX = os.getenv("USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX") == "1"
if USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX:
def prepare_softmax_pattern(x, dim):
xmax = x.amax(dim=dim, keepdim=True)
xsub = x - xmax
xexp = xsub.exp()
xsum = xexp.sum(dim=dim, keepdim=True)
return xmax, xsum, xsub, xexp
def prepare_softmax_replacement(x, dim):
"""
Return xsub since otherwise log-softmax can not be matched
due to a use of this intermediate node. Same reason to return
xsub.exp() for softmax.
"""
from torch._inductor.inductor_prims import prepare_softmax_online
xmax, xsum = prepare_softmax_online(x, dim)
xsub = x - xmax
return xmax, xsum, xsub, xsub.exp()
def prepare_softmax_pattern(x, dim):
xmax = x.amax(dim=dim, keepdim=True)
xsub = x - xmax
xexp = xsub.exp()
xsum = xexp.sum(dim=dim, keepdim=True)
return xmax, xsum, xsub, xexp
def prepare_softmax_replacement(x, dim):
"""
Return xsub since otherwise log-softmax can not be matched
due to a use of this intermediate node. Same reason to return
xsub.exp() for softmax.
"""
from torch._inductor.inductor_prims import prepare_softmax_online
xmax, xsum = prepare_softmax_online(x, dim)
xsub = x - xmax
return xmax, xsum, xsub, xsub.exp()
def prepare_softmax_extra_check(match):
"""
We only have triton online softmax kernels currently.
"""
return (
config.online_softmax
and match.kwargs["x"].meta["val"].device.type == "cuda"
and config.cuda_backend == "triton"
)
@init_once_fakemode
def lazy_init():
@@ -76,16 +87,15 @@ def lazy_init():
_sfdp_init()
_misc_patterns_init()
if USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX:
register_replacement(
prepare_softmax_pattern,
prepare_softmax_replacement,
[torch.empty(4, 8)],
scalar_workaround=dict(dim=-1),
trace_fn=fwd_only,
pass_dicts=pass_patterns[1],
extra_check=lambda *args, **kwargs: config.online_softmax
)
register_replacement(
prepare_softmax_pattern,
prepare_softmax_replacement,
[torch.empty(4, 8)],
scalar_workaround=dict(dim=-1),
trace_fn=fwd_only,
pass_dicts=pass_patterns[1],
extra_check=prepare_softmax_extra_check,
)
def remove_no_ops(


@@ -17,8 +17,12 @@ from torch.utils._sympy.symbol import SymT
from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .ops_handler import WrapperHandler
from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs, reduction_num_outputs
from .utils import (
cache_on_self,
reduction_num_outputs,
sympy_index_symbol_with_prefix,
sympy_subs,
)
from .virtualized import ops, V
@@ -475,10 +479,9 @@ class LoopBodyBlock:
)
return self._inner.store_reduction(name, index, value)
def reduction(self, dtype, src_dtype, reduction_type, value):
result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
num_outputs = reduction_num_outputs(reduction_type)
if num_outputs > 1:
return tuple(result[i] for i in range(num_outputs))