mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "[export] don't decompose custom triton op when exporting (#142426)"
This reverts commit 10b9c5944e.
Reverted https://github.com/pytorch/pytorch/pull/142426 on behalf of https://github.com/huydhn due to This fails one internal MTIA test, checking with the author that we need to revert and reland this ([comment](https://github.com/pytorch/pytorch/pull/142426#issuecomment-2555793496))
This commit is contained in:
parent
fc03c62c56
commit
e9bd74d763
2 changed files with 6 additions and 164 deletions
|
|
@ -72,8 +72,6 @@ from torch.testing._internal.common_utils import (
|
|||
TEST_TRANSFORMERS,
|
||||
TestCase as TorchTestCase,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
||||
from torch.testing._internal.triton_utils import requires_gpu
|
||||
from torch.utils._pytree import (
|
||||
LeafSpec,
|
||||
tree_flatten,
|
||||
|
|
@ -85,12 +83,6 @@ from torch.utils._pytree import (
|
|||
)
|
||||
|
||||
|
||||
if HAS_GPU:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from torch._library import capture_triton
|
||||
|
||||
try:
|
||||
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
|
||||
|
||||
|
|
@ -641,137 +633,6 @@ graph():
|
|||
args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
|
||||
self.assertEqual(gm(*args), m(*args))
|
||||
|
||||
@requires_gpu
def test_export_custom_triton_kernel(self):
    """Exporting a model that calls a `torch.library.triton_op` custom op
    should keep the op as a single high-level node (`torch.ops.mylib.add`)
    in the exported graph, both before and after `run_decompositions()`.
    """

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        # Each program instance handles one contiguous BLOCK_SIZE chunk.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.library.triton_op("mylib::add", mutates_args=())
    def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = torch.empty_like(x)
        n_elements = output.numel()

        def grid(meta):
            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

        capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        return output

    class M(torch.nn.Module):
        def forward(self, x, y):
            return custom_add(x, y)

    args = (
        torch.randn(3, device=GPU_TYPE),
        torch.randn(3, device=GPU_TYPE),
    )
    max_len = 128
    dynamic_shapes = {
        "x": {0: Dim("dim0_x", max=max_len)},
        "y": {0: Dim("dim0_y", max=max_len)},
    }
    m = M()
    ep = export(m, args, dynamic_shapes=dynamic_shapes)

    # The custom triton op must appear exactly once, undecomposed.
    FileCheck().check_count("torch.ops.mylib.add", 1, exactly=True).run(
        ep.graph_module.code
    )
    # Decomposition must also preserve the op rather than lowering it
    # to a triton HOP.
    ep_decomposed = ep.run_decompositions()
    FileCheck().check_count("torch.ops.mylib.add", 1, exactly=True).run(
        ep_decomposed.graph_module.code
    )
    # Numerical parity between eager and the exported module.
    exp_out = m(*args)
    self.assertEqual(exp_out, ep.module()(*args))
|
||||
|
||||
@requires_gpu
def test_export_custom_triton_kernel_mutable(self):
    """Same as the non-mutable test, but the custom triton op mutates its
    `output` argument. Pre-decomposition export keeps the raw op (which
    mutates its input); `run_decompositions()` auto-functionalizes it, so
    the decomposed module must NOT mutate the caller's tensor.
    """

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        # Each program instance handles one contiguous BLOCK_SIZE chunk.
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.library.triton_op("mylib::add", mutates_args={"output"})
    def custom_add_out(
        x: torch.Tensor, y: torch.Tensor, output: torch.Tensor
    ) -> torch.Tensor:
        n_elements = output.numel()

        def grid(meta):
            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

        capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        return output.clone()

    class M(torch.nn.Module):
        def forward(self, x, y, out):
            return custom_add_out(x, y, out)

    args = (
        torch.randn(3, device=GPU_TYPE),
        torch.randn(3, device=GPU_TYPE),
        torch.zeros(3, device=GPU_TYPE),
    )
    # Warm up the kernel once in eager mode before exporting.
    custom_add_out(*args)
    max_len = 128
    dynamic_shapes = {
        "x": {0: Dim("dim0_x", max=max_len)},
        "y": {0: Dim("dim0_y", max=max_len)},
        "out": {0: Dim("dim0_z", max=max_len)},
    }

    m = M()
    ep = export(m, args, dynamic_shapes=dynamic_shapes)

    # Pre-decomposition graph keeps the custom op as-is.
    FileCheck().check_count("torch.ops.mylib.add", 1, exactly=True).run(
        ep.graph_module.code
    )
    # After decomposition the mutating op is wrapped in
    # auto_functionalized exactly once.
    ep_decomposed = ep.run_decompositions()
    FileCheck().check_count(
        "torch.ops.higher_order.auto_functionalized", 1, exactly=True
    ).run(ep_decomposed.graph_module.code)

    x, y, out = (
        torch.randn(3, device=GPU_TYPE),
        torch.randn(3, device=GPU_TYPE),
        torch.zeros(3, device=GPU_TYPE),
    )
    exp_out = m(x, y, out)
    out_copy = out.clone()
    out_copy2 = out.clone()
    out_copy3 = out.clone()
    self.assertEqual(exp_out, ep.module()(x, y, out_copy))
    # The non-functionalized module mutates its `out` argument in place,
    # so out_copy now matches the eager-mutated `out`.
    self.assertEqual(out, out_copy)
    self.assertEqual(exp_out, ep_decomposed.module()(x, y, out_copy2))
    # The functionalized (decomposed) module must leave its `out`
    # argument untouched.
    self.assertEqual(out_copy2, out_copy3)
|
||||
|
||||
def test_masked_select_dynamic(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -130,33 +130,14 @@ def triton_op(
|
|||
# - With torch.compile, this means that the backend (usually Inductor)
|
||||
# can see a call to the triton kernel(s) and so it can directly optimize
|
||||
# them by inlining them into the lowering process.
|
||||
# - With post-dispatch torch.export, this means that there will
|
||||
# be a call(s) to the triton_kernel_wrapper_functional HOP in the
|
||||
# graph (that we have yet to figure out how to serialize).
|
||||
def functional_decomp( # type: ignore[no-untyped-def]
|
||||
mode, op, types, args, kwargs
|
||||
mode, _, types, args, kwargs
|
||||
):
|
||||
# NOTE [Export custom triton op]
|
||||
# For torch.export (strict and non-strict), we don't do functional decomposition.
|
||||
# Instead, we preserve the custom triton ops as custom ops. This is because we want
|
||||
# the exported program to be high-level and serializable. If we decompose
|
||||
# the custom op to a functional hop and make it a node in exported program,
|
||||
# we need to figure out ways of serializing the hop and its arguments, which can be triton.jited
|
||||
# functions and triton dtypes. This is undesirable because:
|
||||
# - it can be tedious to maintain a layer that serializes the jited function (e.g. with a string) and dtypes.
|
||||
# - exported program will contain the implementation detail (e.g. triton source code) for a specific
|
||||
# backend (GPU), which is probably at a wrong level of abstraction.
|
||||
# - changes to triton or the serialization logic for triton arguments can be BC breaking
|
||||
#
|
||||
# In the short term, we expect users to have a separate aot_compile stage that compiles the exported program
|
||||
# into a Cubin file on the same machine that users call export, which does autotuning and removes triton
|
||||
# dependency and serves the model with Cubin. This guarantees that triton changes won't break BC.
|
||||
# In the long term, we may export multiple cubins for the triton op directly.
|
||||
|
||||
from torch.compiler import is_exporting
|
||||
|
||||
if is_exporting():
|
||||
return mode.__torch_dispatch__(op, types, args, kwargs)
|
||||
else:
|
||||
with mode:
|
||||
return fn(*args, **kwargs)
|
||||
with mode:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
|
||||
return result
|
||||
|
|
|
|||
Loading…
Reference in a new issue