mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update on "[inductor] online softmax"
Softmax need do some preparation work that access the input tensor in two passes - compute amax of each row - compute (x - amax).exp.sum for each row When the row size is large, cache can not hold all the active data and accessing the input multiple passes increases execution time since the kernel is membw bounded. Online softmax uses a customized reduction to compute max and sum at the same time by accessing the data in one pass. Check this paper for more details ( https://arxiv.org/abs/1805.02867 ). Also here is an online softmax kernel generated by inductor as a reference: https://gist.github.com/shunting314/67ae4fffd45d4f2753c781780332fa54 ## Microbenchmark - `TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_ONLINE_SOFTMAX=0 DO_PERF_TEST=1 python test/inductor/test_online_softmax.py -k test_softmax` : without online softmax - eager_ms=6.671296119689941 - opt_ms=8.06931209564209 - `TORCHINDUCTOR_COORDINATE_DESCENT_TUNING=1 TORCHINDUCTOR_ONLINE_SOFTMAX=1 DO_PERF_TEST=1 python test/inductor/test_online_softmax.py -k test_softmax`: with online softmax - eager_ms=6.634047985076904 - opt_ms=6.230591773986816 Ideally, online softmax should save about 2ms here. We saves about 1.84ms in practice. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang aakhundov peterbell10 ngimel [ghstack-poisoned]
This commit is contained in:
commit
426e1afd11
4 changed files with 94 additions and 127 deletions
|
|
@ -1,16 +1,21 @@
|
|||
# Owner(s): ["module: inductor"]
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
from triton.testing import do_bench
|
||||
|
||||
import torch
|
||||
import torch._inductor.config as inductor_config
|
||||
from torch._dynamo.utils import same
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
from torch.testing._internal.common_utils import IS_LINUX, parametrize, instantiate_parametrized_tests
|
||||
from torch.testing._internal.inductor_utils import HAS_CUDA, GPU_TYPE
|
||||
from torch._dynamo.utils import same
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
IS_LINUX,
|
||||
parametrize,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA
|
||||
|
||||
|
||||
DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
|
||||
|
|
@ -43,7 +48,7 @@ class TestOnlineSoftmax(TestCase):
|
|||
def test_log_softmax(self):
|
||||
self.do_test_acc_and_perf(torch.log_softmax)
|
||||
|
||||
def get_softmax_wrapper(self, V=50304, use_log_softmax=False):
|
||||
def get_softmax_wrapper(self, V=50304, use_log_softmax=False, device=GPU_TYPE):
|
||||
N = 32 * 1024
|
||||
|
||||
@torch.compile
|
||||
|
|
@ -53,7 +58,7 @@ class TestOnlineSoftmax(TestCase):
|
|||
else:
|
||||
return torch.softmax(x, dim=-1)
|
||||
|
||||
x = torch.randn(N, V, dtype=torch.bfloat16, device=GPU_TYPE)
|
||||
x = torch.randn(N, V, dtype=torch.bfloat16, device=device)
|
||||
out, source_codes = run_and_get_code(f, x)
|
||||
return source_codes[0]
|
||||
|
||||
|
|
@ -70,6 +75,12 @@ class TestOnlineSoftmax(TestCase):
|
|||
|
||||
self.assertEqual(wrapper_code.count("for r0_offset in"), 2)
|
||||
|
||||
def test_no_online_softmax_for_cpu(self):
|
||||
code = self.get_softmax_wrapper(V=2048, device="cpu")
|
||||
|
||||
# CPU need an explicit loop across different rows.
|
||||
# For GPU, this is parallelized by the hardware.
|
||||
self.assertEqual(code.count("for(int64_t"), 4)
|
||||
|
||||
def test_codegen_softmax_persistent_reduction(self):
|
||||
"""
|
||||
|
|
@ -78,11 +89,34 @@ class TestOnlineSoftmax(TestCase):
|
|||
wrapper_code = self.get_softmax_wrapper(1024)
|
||||
self.assertEqual(wrapper_code.count("for r0_offset in"), 0)
|
||||
|
||||
# This test only work if we use pattern matcher rather the decompose
|
||||
# softmax/log_softmax specially.
|
||||
@inductor_config.patch("triton.persistent_reductions", False)
|
||||
def test_sdpa(self):
|
||||
"""
|
||||
Make sure online softmax here does not conflict with the sdpa
|
||||
patterns.
|
||||
"""
|
||||
q, k, v = (
|
||||
torch.randn((4, 2, 16, 32), device=GPU_TYPE, dtype=torch.bfloat16)
|
||||
for _ in range(3)
|
||||
)
|
||||
|
||||
def f(q, k, v):
|
||||
return (
|
||||
torch.matmul(q, k.transpose(-2, -1))
|
||||
.div(math.sqrt(k.shape[-1]))
|
||||
.softmax(dim=-1)
|
||||
.matmul(v)
|
||||
)
|
||||
|
||||
opt_f = torch.compile(f)
|
||||
ref = f(q, k, v)
|
||||
act, (code,) = run_and_get_code(opt_f, q, k, v)
|
||||
self.assertTrue(torch.allclose(ref, act, atol=1e-2, rtol=1e-2))
|
||||
self.assertTrue("aten._scaled_dot_product_" in code)
|
||||
|
||||
@parametrize("nrow", [2, 2048])
|
||||
@parametrize("dim", [-1, 0, 1])
|
||||
def no_test_prepare_softmax(self, dim, nrow):
|
||||
def test_prepare_softmax(self, dim, nrow):
|
||||
def f(x, dim):
|
||||
xmax = x.amax(dim=dim, keepdim=True)
|
||||
xsum = (x - xmax).exp().sum(dim=dim, keepdim=True)
|
||||
|
|
@ -101,6 +135,7 @@ class TestOnlineSoftmax(TestCase):
|
|||
expected_num_loop = 1
|
||||
self.assertEqual(code.count("for r0_offset in"), expected_num_loop)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestOnlineSoftmax)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -51,9 +51,6 @@ quantized = torch.ops.quantized
|
|||
_quantized = torch.ops._quantized
|
||||
quantized_decomposed = torch.ops.quantized_decomposed
|
||||
|
||||
import os
|
||||
USE_DECOMP_FOR_ONLINE_SOFTMAX = os.getenv("USE_DECOMP_FOR_ONLINE_SOFTMAX", "1") == "1"
|
||||
|
||||
inductor_decompositions = get_decompositions(
|
||||
[
|
||||
aten._adaptive_avg_pool2d_backward,
|
||||
|
|
@ -71,6 +68,7 @@ inductor_decompositions = get_decompositions(
|
|||
aten.lcm,
|
||||
aten.leaky_relu,
|
||||
aten.linalg_vector_norm,
|
||||
aten._log_softmax,
|
||||
aten.max_pool2d_with_indices_backward,
|
||||
aten._native_batch_norm_legit,
|
||||
aten._native_batch_norm_legit_functional,
|
||||
|
|
@ -85,6 +83,7 @@ inductor_decompositions = get_decompositions(
|
|||
aten.nll_loss2d_backward,
|
||||
aten.permute_copy,
|
||||
aten.rrelu_with_noise_backward,
|
||||
aten._softmax,
|
||||
aten.sin_,
|
||||
aten.sqrt_,
|
||||
out_dtype,
|
||||
|
|
@ -95,7 +94,7 @@ inductor_decompositions = get_decompositions(
|
|||
aten.upsample_bilinear2d.vec,
|
||||
quantized.linear_dynamic_fp16_unpacked_weight,
|
||||
_quantized.wrapped_quantized_linear,
|
||||
] + ([] if USE_DECOMP_FOR_ONLINE_SOFTMAX else [aten._log_softmax, aten._softmax])
|
||||
]
|
||||
)
|
||||
decompositions = {**core_aten_decompositions(), **inductor_decompositions}
|
||||
|
||||
|
|
@ -120,12 +119,6 @@ decomps_to_exclude = [
|
|||
aten.baddbmm, # upcasts to fp32, perf issue
|
||||
]
|
||||
|
||||
if USE_DECOMP_FOR_ONLINE_SOFTMAX:
|
||||
decomps_to_exclude += [
|
||||
aten._softmax, # inductor will override this rule
|
||||
aten._log_softmax, # inductor will override this rule
|
||||
]
|
||||
|
||||
remove_decompositions(decompositions, decomps_to_exclude)
|
||||
|
||||
|
||||
|
|
@ -1076,77 +1069,3 @@ def rrelu_with_noise_functional(
|
|||
else:
|
||||
negative_slope = (lower + upper) / 2
|
||||
return aten.leaky_relu(self, negative_slope), torch.Tensor()
|
||||
|
||||
|
||||
def _use_online_softmax(x: torch.Tensor, dim: int) -> bool:
|
||||
if not config.online_softmax:
|
||||
return False
|
||||
|
||||
# Don't do online softmax for scalar or 1d tensor
|
||||
if x.dim() < 2:
|
||||
return False
|
||||
|
||||
# Only do online softmax for GPU for now
|
||||
if x.device.type != "cuda":
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if USE_DECOMP_FOR_ONLINE_SOFTMAX:
|
||||
@register_decomposition(aten._softmax)
|
||||
def _softmax(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor:
|
||||
# eager softmax returns a contiguous tensor. Ensure that decomp also returns
|
||||
# a contiguous tensor.
|
||||
x = x.contiguous()
|
||||
if half_to_float:
|
||||
assert x.dtype == torch.half
|
||||
computation_dtype, result_dtype = utils.elementwise_dtypes(
|
||||
x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
|
||||
)
|
||||
x = x.to(computation_dtype)
|
||||
if x.numel() == 0:
|
||||
unnormalized = torch.exp(x)
|
||||
result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
|
||||
elif not _use_online_softmax(x, dim):
|
||||
# don't want to affect small softmax. That may inferfere with
|
||||
# the attention patterns.
|
||||
x_max = torch.amax(x, dim, keepdim=True)
|
||||
unnormalized = torch.exp(x - x_max)
|
||||
result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
|
||||
else:
|
||||
x_max, t_sum = inductor_prims.online_softmax(x, dim)
|
||||
result = torch.exp(x - x_max) / t_sum
|
||||
|
||||
if not half_to_float:
|
||||
result = result.to(result_dtype)
|
||||
return result
|
||||
|
||||
|
||||
@register_decomposition(aten._log_softmax)
|
||||
def _log_softmax(x: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor:
|
||||
# eager log_softmax returns a contiguous tensor. Ensure that decomp also
|
||||
# returns a contiguous tensor.
|
||||
x = x.contiguous()
|
||||
if half_to_float:
|
||||
assert x.dtype == torch.half
|
||||
computation_dtype, result_dtype = utils.elementwise_dtypes(
|
||||
x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
|
||||
)
|
||||
x = x.to(computation_dtype)
|
||||
if x.numel() == 0:
|
||||
shifted = x
|
||||
shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
|
||||
elif not _use_online_softmax(x, dim):
|
||||
x_max = torch.amax(x, dim, keepdim=True)
|
||||
shifted = x - x_max
|
||||
shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
|
||||
else:
|
||||
x_max, t_sum = inductor_prims.online_softmax(x, dim)
|
||||
shifted = x - x_max
|
||||
shifted_logsumexp = torch.log(t_sum)
|
||||
|
||||
result = shifted - shifted_logsumexp
|
||||
if not half_to_float:
|
||||
result = result.to(result_dtype)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from torch.utils._ordered_set import OrderedSet
|
|||
from .. import config
|
||||
from ..pattern_matcher import (
|
||||
CallFunction,
|
||||
fwd_only,
|
||||
init_once_fakemode,
|
||||
KeywordArg,
|
||||
Match,
|
||||
|
|
@ -28,7 +29,6 @@ from ..pattern_matcher import (
|
|||
PatternMatcherPass,
|
||||
register_graph_pattern,
|
||||
register_replacement,
|
||||
fwd_only,
|
||||
stable_topological_sort,
|
||||
)
|
||||
from .replace_random import replace_random_passes
|
||||
|
|
@ -44,27 +44,38 @@ pass_patterns = [
|
|||
PatternMatcherPass(),
|
||||
]
|
||||
|
||||
import os
|
||||
USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX = os.getenv("USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX") == "1"
|
||||
|
||||
if USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX:
|
||||
def prepare_softmax_pattern(x, dim):
|
||||
xmax = x.amax(dim=dim, keepdim=True)
|
||||
xsub = x - xmax
|
||||
xexp = xsub.exp()
|
||||
xsum = xexp.sum(dim=dim, keepdim=True)
|
||||
return xmax, xsum, xsub, xexp
|
||||
|
||||
def prepare_softmax_replacement(x, dim):
|
||||
"""
|
||||
Return xsub since otherwise log-softmax can not be matched
|
||||
due to a use of this intermediate node. Same reason to return
|
||||
xsub.exp() for softmax.
|
||||
"""
|
||||
from torch._inductor.inductor_prims import prepare_softmax_online
|
||||
xmax, xsum = prepare_softmax_online(x, dim)
|
||||
xsub = x - xmax
|
||||
return xmax, xsum, xsub, xsub.exp()
|
||||
def prepare_softmax_pattern(x, dim):
|
||||
xmax = x.amax(dim=dim, keepdim=True)
|
||||
xsub = x - xmax
|
||||
xexp = xsub.exp()
|
||||
xsum = xexp.sum(dim=dim, keepdim=True)
|
||||
return xmax, xsum, xsub, xexp
|
||||
|
||||
|
||||
def prepare_softmax_replacement(x, dim):
|
||||
"""
|
||||
Return xsub since otherwise log-softmax can not be matched
|
||||
due to a use of this intermediate node. Same reason to return
|
||||
xsub.exp() for softmax.
|
||||
"""
|
||||
from torch._inductor.inductor_prims import prepare_softmax_online
|
||||
|
||||
xmax, xsum = prepare_softmax_online(x, dim)
|
||||
xsub = x - xmax
|
||||
return xmax, xsum, xsub, xsub.exp()
|
||||
|
||||
|
||||
def prepare_softmax_extra_check(match):
|
||||
"""
|
||||
We only have triton online softmax kernels currently.
|
||||
"""
|
||||
return (
|
||||
config.online_softmax
|
||||
and match.kwargs["x"].meta["val"].device.type == "cuda"
|
||||
and config.cuda_backend == "triton"
|
||||
)
|
||||
|
||||
|
||||
@init_once_fakemode
|
||||
def lazy_init():
|
||||
|
|
@ -76,16 +87,15 @@ def lazy_init():
|
|||
_sfdp_init()
|
||||
_misc_patterns_init()
|
||||
|
||||
if USE_PATTERN_MATCHER_FOR_ONLINE_SOFTMAX:
|
||||
register_replacement(
|
||||
prepare_softmax_pattern,
|
||||
prepare_softmax_replacement,
|
||||
[torch.empty(4, 8)],
|
||||
scalar_workaround=dict(dim=-1),
|
||||
trace_fn=fwd_only,
|
||||
pass_dicts=pass_patterns[1],
|
||||
extra_check=lambda *args, **kwargs: config.online_softmax
|
||||
)
|
||||
register_replacement(
|
||||
prepare_softmax_pattern,
|
||||
prepare_softmax_replacement,
|
||||
[torch.empty(4, 8)],
|
||||
scalar_workaround=dict(dim=-1),
|
||||
trace_fn=fwd_only,
|
||||
pass_dicts=pass_patterns[1],
|
||||
extra_check=prepare_softmax_extra_check,
|
||||
)
|
||||
|
||||
|
||||
def remove_no_ops(
|
||||
|
|
|
|||
|
|
@ -17,8 +17,12 @@ from torch.utils._sympy.symbol import SymT
|
|||
|
||||
from . import config, dependencies
|
||||
from .codegen.common import index_prevent_reordering
|
||||
from .ops_handler import WrapperHandler
|
||||
from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs, reduction_num_outputs
|
||||
from .utils import (
|
||||
cache_on_self,
|
||||
reduction_num_outputs,
|
||||
sympy_index_symbol_with_prefix,
|
||||
sympy_subs,
|
||||
)
|
||||
from .virtualized import ops, V
|
||||
|
||||
|
||||
|
|
@ -475,10 +479,9 @@ class LoopBodyBlock:
|
|||
)
|
||||
return self._inner.store_reduction(name, index, value)
|
||||
|
||||
|
||||
def reduction(self, dtype, src_dtype, reduction_type, value):
|
||||
result = self._inner.reduction(dtype, src_dtype, reduction_type, value)
|
||||
|
||||
|
||||
num_outputs = reduction_num_outputs(reduction_type)
|
||||
if num_outputs > 1:
|
||||
return tuple(result[i] for i in range(num_outputs))
|
||||
|
|
|
|||
Loading…
Reference in a new issue