mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "[inductor] use ftz variant of exp (#146216)"
This reverts commit b0b3fe8bcf. Reverted https://github.com/pytorch/pytorch/pull/146216 on behalf of https://github.com/atalman due to inductor/test_op_completeness.py::TestOpCompleteness::test_triton_overrides [GH job link](https://github.com/pytorch/pytorch/actions/runs/13152430750/job/36702812599) [HUD commit link](b0b3fe8bcf) ([comment](https://github.com/pytorch/pytorch/pull/146216#issuecomment-2636961317))
This commit is contained in:
parent
8a2000fd42
commit
282d185ec1
3 changed files with 1 additions and 61 deletions
|
|
@@ -128,7 +128,6 @@ from torch.testing._internal.inductor_utils import (
|
|||
skipCPUIf,
|
||||
skipCUDAIf,
|
||||
)
|
||||
from torch.testing._internal.triton_utils import requires_cuda
|
||||
|
||||
|
||||
HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines
|
||||
|
|
@@ -12364,50 +12363,6 @@ class CommonTemplate:
|
|||
|
||||
torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3)
|
||||
|
||||
@requires_cuda
|
||||
@config.patch(use_fast_math=True)
|
||||
def test_prepare_softmax_with_fast_math(self):
|
||||
"""
|
||||
Measure on a A100, perf is 3.487ms v.s. 3.358ms without or with flushing to zero. A 4% speedup.
|
||||
"""
|
||||
if DO_PERF_TEST:
|
||||
M = 32768
|
||||
N = 50304
|
||||
else:
|
||||
# Use small shapes if not doing perf test
|
||||
M = 128
|
||||
N = 128
|
||||
x = torch.randn(M, N, dtype=torch.bfloat16, device=GPU_TYPE)
|
||||
|
||||
def f(x):
|
||||
"""
|
||||
Not calling softmax directly to generate kernel just for
|
||||
computation of max & sum.
|
||||
|
||||
If we call softmax directly, the computation of the final
|
||||
result will double the membw usage. In that case saving
|
||||
computation does not matter much.
|
||||
|
||||
In reality during training, since max & sum need to be saved
|
||||
for bwd and the computation of softmax result is fused with
|
||||
other kernels, we do see such prepare_softmax kernel appear
|
||||
in real models.
|
||||
"""
|
||||
x_max = x.amax(dim=-1, keepdim=True)
|
||||
x_sum = (x - x_max).exp().sum(dim=-1, keepdim=True).log()
|
||||
return x_max, x_sum
|
||||
|
||||
opt_f = torch.compile(f)
|
||||
ref = f(x)
|
||||
act = opt_f(x)
|
||||
self.assertTrue(same(ref, act, tol=1e-2), f"Ref:\n{ref}\nAct:\n{act}")
|
||||
|
||||
if DO_PERF_TEST:
|
||||
from triton.testing import do_bench
|
||||
|
||||
ms = do_bench(lambda: opt_f(x))
|
||||
print(f"{ms=:.3f}")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TestFailure:
|
||||
|
|
|
|||
|
|
@@ -7,7 +7,6 @@ import dataclasses
|
|||
import functools
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import textwrap
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
|
@@ -817,8 +816,6 @@ def maybe_upcast_float32(convert_output: bool = True) -> Callable[[_T], _T]:
|
|||
class TritonOverrides(OpOverrides):
|
||||
"""Map element-wise ops to Triton"""
|
||||
|
||||
LOG_2_E = math.log2(math.e)
|
||||
|
||||
@staticmethod
|
||||
def to_dtype(
|
||||
x,
|
||||
|
|
@@ -934,17 +931,7 @@ class TritonOverrides(OpOverrides):
|
|||
@staticmethod
|
||||
@maybe_upcast_float32()
|
||||
def exp(x):
|
||||
"""
|
||||
When use_fast_math, use the ftz (flushing to zero) variant
|
||||
of exponent computation.
|
||||
|
||||
Check https://github.com/triton-lang/triton/issues/5735 for
|
||||
more details.
|
||||
"""
|
||||
if config.use_fast_math:
|
||||
return f"libdevice.exp2({x} * {TritonOverrides.LOG_2_E})"
|
||||
else:
|
||||
return f"tl_math.exp({x})"
|
||||
return f"tl_math.exp({x})"
|
||||
|
||||
@staticmethod
|
||||
@maybe_upcast_float32()
|
||||
|
|
|
|||
|
|
@@ -148,8 +148,6 @@ allow_buffer_reuse = True
|
|||
# Enable pooled allocations for non-output tensors
|
||||
memory_planning = os.environ.get("TORCHINDUCTOR_MEMORY_PLANNING", "0") == "1"
|
||||
|
||||
use_fast_math = os.environ.get("TORCHINDUCTOR_USE_FAST_MATH") == "1"
|
||||
|
||||
# How to organize memory under memory_planning=True:
|
||||
# - "none": do not try to pool storage, just reuse
|
||||
# - "intermediates": all non-outputs share storage, outputs each get unique storage
|
||||
|
|
|
|||
Loading…
Reference in a new issue