Revert "[export] don't decompose custom triton op when exporting (#142426)"

This reverts commit 10b9c5944e.

Reverted https://github.com/pytorch/pytorch/pull/142426 on behalf of https://github.com/huydhn because it fails one internal MTIA test; checking with the author whether we need to revert and reland this ([comment](https://github.com/pytorch/pytorch/pull/142426#issuecomment-2555793496))
PyTorch MergeBot 2024-12-19 21:21:38 +00:00
parent fc03c62c56
commit e9bd74d763
2 changed files with 6 additions and 164 deletions


@@ -72,8 +72,6 @@ from torch.testing._internal.common_utils import (
     TEST_TRANSFORMERS,
     TestCase as TorchTestCase,
 )
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
-from torch.testing._internal.triton_utils import requires_gpu
 from torch.utils._pytree import (
     LeafSpec,
     tree_flatten,
@@ -85,12 +83,6 @@ from torch.utils._pytree import (
 )
 
-if HAS_GPU:
-    import triton
-    import triton.language as tl
-
-    from torch._library import capture_triton
-
 
 try:
     from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -641,137 +633,6 @@ graph():
         args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
         self.assertEqual(gm(*args), m(*args))
 
-    @requires_gpu
-    def test_export_custom_triton_kernel(self):
-        @triton.jit
-        def add_kernel(
-            in_ptr0,
-            in_ptr1,
-            out_ptr,
-            n_elements,
-            BLOCK_SIZE: "tl.constexpr",
-        ):
-            pid = tl.program_id(axis=0)
-            block_start = pid * BLOCK_SIZE
-            offsets = block_start + tl.arange(0, BLOCK_SIZE)
-            mask = offsets < n_elements
-            x = tl.load(in_ptr0 + offsets, mask=mask)
-            y = tl.load(in_ptr1 + offsets, mask=mask)
-            output = x + y
-            tl.store(out_ptr + offsets, output, mask=mask)
-
-        @torch.library.triton_op("mylib::add", mutates_args=())
-        def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-            output = torch.empty_like(x)
-            n_elements = output.numel()
-
-            def grid(meta):
-                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
-
-            capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
-            return output
-
-        class M(torch.nn.Module):
-            def forward(self, x, y):
-                return custom_add(x, y)
-
-        args = (
-            torch.randn(3, device=GPU_TYPE),
-            torch.randn(3, device=GPU_TYPE),
-        )
-        max_len = 128
-        dynamic_shapes = {
-            "x": {0: Dim("dim0_x", max=max_len)},
-            "y": {0: Dim("dim0_y", max=max_len)},
-        }
-        m = M()
-        ep = export(m, args, dynamic_shapes=dynamic_shapes)
-        FileCheck().check_count("torch.ops.mylib.add", 1, exactly=True).run(
-            ep.graph_module.code
-        )
-        ep_decomposed = ep.run_decompositions()
-        FileCheck().check_count("torch.ops.mylib.add", 1, exactly=True).run(
-            ep_decomposed.graph_module.code
-        )
-        exp_out = m(*args)
-        self.assertEqual(exp_out, ep.module()(*args))
-
-    @requires_gpu
-    def test_export_custom_triton_kernel_mutable(self):
-        @triton.jit
-        def add_kernel(
-            in_ptr0,
-            in_ptr1,
-            out_ptr,
-            n_elements,
-            BLOCK_SIZE: "tl.constexpr",
-        ):
-            pid = tl.program_id(axis=0)
-            block_start = pid * BLOCK_SIZE
-            offsets = block_start + tl.arange(0, BLOCK_SIZE)
-            mask = offsets < n_elements
-            x = tl.load(in_ptr0 + offsets, mask=mask)
-            y = tl.load(in_ptr1 + offsets, mask=mask)
-            output = x + y
-            tl.store(out_ptr + offsets, output, mask=mask)
-
-        @torch.library.triton_op("mylib::add", mutates_args={"output"})
-        def custom_add_out(
-            x: torch.Tensor, y: torch.Tensor, output: torch.Tensor
-        ) -> torch.Tensor:
-            n_elements = output.numel()
-
-            def grid(meta):
-                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
-
-            capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
-            return output.clone()
-
-        class M(torch.nn.Module):
-            def forward(self, x, y, out):
-                return custom_add_out(x, y, out)
-
-        args = (
-            torch.randn(3, device=GPU_TYPE),
-            torch.randn(3, device=GPU_TYPE),
-            torch.zeros(3, device=GPU_TYPE),
-        )
-        custom_add_out(*args)
-        max_len = 128
-        dynamic_shapes = {
-            "x": {0: Dim("dim0_x", max=max_len)},
-            "y": {0: Dim("dim0_y", max=max_len)},
-            "out": {0: Dim("dim0_z", max=max_len)},
-        }
-        m = M()
-        ep = export(m, args, dynamic_shapes=dynamic_shapes)
-        FileCheck().check_count("torch.ops.mylib.add", 1, exactly=True).run(
-            ep.graph_module.code
-        )
-        ep_decomposed = ep.run_decompositions()
-        FileCheck().check_count(
-            "torch.ops.higher_order.auto_functionalized", 1, exactly=True
-        ).run(ep_decomposed.graph_module.code)
-        x, y, out = (
-            torch.randn(3, device=GPU_TYPE),
-            torch.randn(3, device=GPU_TYPE),
-            torch.zeros(3, device=GPU_TYPE),
-        )
-        exp_out = m(x, y, out)
-        out_copy = out.clone()
-        out_copy2 = out.clone()
-        out_copy3 = out.clone()
-        self.assertEqual(exp_out, ep.module()(x, y, out_copy))
-        # For the non-functional graph module, out_copy is mutated
-        self.assertEqual(out, out_copy)
-        self.assertEqual(exp_out, ep_decomposed.module()(x, y, out_copy2))
-        # For the functional (decomposed) graph module, out_copy2 is not mutated
-        self.assertEqual(out_copy2, out_copy3)
-
     def test_masked_select_dynamic(self):
         class M(torch.nn.Module):
             def __init__(self) -> None:


@@ -130,33 +130,14 @@ def triton_op(
     # - With torch.compile, this means that the backend (usually Inductor)
     #   can see a call to the triton kernel(s) and so it can directly optimize
     #   them by inlining them into the lowering process.
+    # - With post-dispatch torch.export, this means that there will
+    #   be a call(s) to the triton_kernel_wrapper_functional HOP in the
+    #   graph (that we have yet to figure out how to serialize).
     def functional_decomp(  # type: ignore[no-untyped-def]
-        mode, op, types, args, kwargs
+        mode, _, types, args, kwargs
     ):
-        # NOTE [Export custom triton op]
-        # For torch.export (strict and non-strict), we don't do functional decomposition.
-        # Instead, we preserve the custom triton ops as custom ops. This is because we want
-        # the exported program to be high-level and serializable. If we decompose
-        # the custom op to a functional hop and make it a node in the exported program,
-        # we need to figure out ways of serializing the hop and its arguments, which can be
-        # triton.jit-ed functions and triton dtypes. This is undesirable because:
-        # - it can be tedious to maintain a layer that serializes the jit-ed function (e.g. with a string) and dtypes.
-        # - the exported program will contain the implementation detail (e.g. triton source code) for a specific
-        #   backend (GPU), which is probably at the wrong level of abstraction.
-        # - changes to triton or the serialization logic for triton arguments can be BC breaking.
-        #
-        # In the short term, we expect users to have a separate aot_compile stage that compiles the exported
-        # program into a Cubin file on the same machine where they call export, which does autotuning and
-        # removes the triton dependency, and to serve the model with the Cubin. This guarantees that triton
-        # changes won't break BC. In the long term, we may export multiple Cubins for the triton op directly.
-        from torch.compiler import is_exporting
-
-        if is_exporting():
-            return mode.__torch_dispatch__(op, types, args, kwargs)
-        else:
-            with mode:
-                return fn(*args, **kwargs)
+        with mode:
+            return fn(*args, **kwargs)

     result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
     return result
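
For context, below is a minimal, self-contained sketch condensed from the test removed above; the op name "mylib::add", the tensor shapes, and the hard-coded "cuda" device are illustrative, and a Triton-capable GPU is assumed. With #142426 applied, export() kept the call as torch.ops.mylib.add in the exported graph; after this revert, the FunctionalTensorMode decomposition applies again, so the exported graph instead contains the triton_kernel_wrapper_functional HOP, as described in the restored comment above.

# Minimal sketch (condensed from the removed test above; names, shapes, and
# device are illustrative). Defines a custom op backed by a Triton kernel and
# exports a module that calls it, so the effect of the revert can be seen by
# inspecting ep.graph_module.code.
import torch
import triton
import triton.language as tl
from torch._library import capture_triton
from torch.export import Dim, export


@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


@torch.library.triton_op("mylib::add", mutates_args=())
def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # capture_triton wraps the kernel so the launch can be traced by export/compile.
    capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
    return output


class M(torch.nn.Module):
    def forward(self, x, y):
        return custom_add(x, y)


args = (torch.randn(3, device="cuda"), torch.randn(3, device="cuda"))
dynamic_shapes = {"x": {0: Dim("dim0_x", max=128)}, "y": {0: Dim("dim0_y", max=128)}}
ep = export(M(), args, dynamic_shapes=dynamic_shapes)
# With #142426: the printed graph shows a call to torch.ops.mylib.add.
# After this revert: the call is decomposed under FunctionalTensorMode instead.
print(ep.graph_module.code)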