diff --git a/test/export/test_export.py b/test/export/test_export.py
index cbadc67e3a9..1ed4d9df356 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -72,8 +72,6 @@ from torch.testing._internal.common_utils import (
     TEST_TRANSFORMERS,
     TestCase as TorchTestCase,
 )
-from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
-from torch.testing._internal.triton_utils import requires_gpu
 from torch.utils._pytree import (
     LeafSpec,
     tree_flatten,
@@ -85,12 +83,6 @@ from torch.utils._pytree import (
 )
 
 
-if HAS_GPU:
-    import triton
-    import triton.language as tl
-
-    from torch._library import capture_triton
-
 try:
     from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
@@ -641,137 +633,6 @@ graph():
         args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
         self.assertEqual(gm(*args), m(*args))
 
-    @requires_gpu
-    def test_export_custom_triton_kernel(self):
-        @triton.jit
-        def add_kernel(
-            in_ptr0,
-            in_ptr1,
-            out_ptr,
-            n_elements,
-            BLOCK_SIZE: "tl.constexpr",
-        ):
-            pid = tl.program_id(axis=0)
-            block_start = pid * BLOCK_SIZE
-            offsets = block_start + tl.arange(0, BLOCK_SIZE)
-            mask = offsets < n_elements
-            x = tl.load(in_ptr0 + offsets, mask=mask)
-            y = tl.load(in_ptr1 + offsets, mask=mask)
-            output = x + y
-            tl.store(out_ptr + offsets, output, mask=mask)
-
-        @torch.library.triton_op("mylib::add", mutates_args=())
-        def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-            output = torch.empty_like(x)
-            n_elements = output.numel()
-
-            def grid(meta):
-                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
-
-            capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
-            return output
-
-        class M(torch.nn.Module):
-            def forward(self, x, y):
-                return custom_add(x, y)
-
-        args = (
-            torch.randn(3, device=GPU_TYPE),
-            torch.randn(3, device=GPU_TYPE),
-        )
-        max_len = 128
-        dynamic_shapes = {
-            "x": {0: Dim("dim0_x", max=max_len)},
-            "y": {0: Dim("dim0_y", max=max_len)},
-        }
-        m = M()
-        ep = export(m, args, dynamic_shapes=dynamic_shapes)
-
-        FileCheck().check_count("torch.ops.mylib.add", 1, exactly=True).run(
-            ep.graph_module.code
-        )
-        ep_decomposed = ep.run_decompositions()
-        FileCheck().check_count("torch.ops.mylib.add", 1, exactly=True).run(
-            ep_decomposed.graph_module.code
-        )
-        exp_out = m(*args)
-        self.assertEqual(exp_out, ep.module()(*args))
-
-    @requires_gpu
-    def test_export_custom_triton_kernel_mutable(self):
-        @triton.jit
-        def add_kernel(
-            in_ptr0,
-            in_ptr1,
-            out_ptr,
-            n_elements,
-            BLOCK_SIZE: "tl.constexpr",
-        ):
-            pid = tl.program_id(axis=0)
-            block_start = pid * BLOCK_SIZE
-            offsets = block_start + tl.arange(0, BLOCK_SIZE)
-            mask = offsets < n_elements
-            x = tl.load(in_ptr0 + offsets, mask=mask)
-            y = tl.load(in_ptr1 + offsets, mask=mask)
-            output = x + y
-            tl.store(out_ptr + offsets, output, mask=mask)
-
-        @torch.library.triton_op("mylib::add", mutates_args={"output"})
-        def custom_add_out(
-            x: torch.Tensor, y: torch.Tensor, output: torch.Tensor
-        ) -> torch.Tensor:
-            n_elements = output.numel()
-
-            def grid(meta):
-                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
-
-            capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
-            return output.clone()
-
-        class M(torch.nn.Module):
-            def forward(self, x, y, out):
-                return custom_add_out(x, y, out)
-
-        args = (
-            torch.randn(3, device=GPU_TYPE),
-            torch.randn(3, device=GPU_TYPE),
-            torch.zeros(3, device=GPU_TYPE),
-        )
-        custom_add_out(*args)
-        max_len = 128
-        dynamic_shapes = {
-            "x": {0: Dim("dim0_x", max=max_len)},
-            "y": {0: Dim("dim0_y", max=max_len)},
-            "out": {0: Dim("dim0_z", max=max_len)},
-        }
-
-        m = M()
-        ep = export(m, args, dynamic_shapes=dynamic_shapes)
-
-        FileCheck().check_count("torch.ops.mylib.add", 1, exactly=True).run(
-            ep.graph_module.code
-        )
-        ep_decomposed = ep.run_decompositions()
-        FileCheck().check_count(
-            "torch.ops.higher_order.auto_functionalized", 1, exactly=True
-        ).run(ep_decomposed.graph_module.code)
-
-        x, y, out = (
-            torch.randn(3, device=GPU_TYPE),
-            torch.randn(3, device=GPU_TYPE),
-            torch.zeros(3, device=GPU_TYPE),
-        )
-        exp_out = m(x, y, out)
-        out_copy = out.clone()
-        out_copy2 = out.clone()
-        out_copy3 = out.clone()
-        self.assertEqual(exp_out, ep.module()(x, y, out_copy))
-        # For non-functional graph module, out_copy is mutated
-        self.assertEqual(out, out_copy)
-        self.assertEqual(exp_out, ep_decomposed.module()(x, y, out_copy2))
-        # For non-functional graph module, out_copy is not mutated
-        self.assertEqual(out_copy2, out_copy3)
-
     def test_masked_select_dynamic(self):
         class M(torch.nn.Module):
             def __init__(self) -> None:
diff --git a/torch/_library/triton.py b/torch/_library/triton.py
index 9acd8cd7eab..797f5533ec5 100644
--- a/torch/_library/triton.py
+++ b/torch/_library/triton.py
@@ -130,33 +130,14 @@ def triton_op(
         # - With torch.compile, this means that the backend (usually Inductor)
         #   can see a call to the triton kernel(s) and so it can directly optimize
         #   them by inlining them into the lowering process.
+        # - With post-dispatch torch.export, this means that there will
+        #   be a call(s) to the triton_kernel_wrapper_functional HOP in the
+        #   graph (that we have yet to figure out how to serialize).
         def functional_decomp(  # type: ignore[no-untyped-def]
-            mode, op, types, args, kwargs
+            mode, _, types, args, kwargs
         ):
-            # NOTE [Export custom triton op]
-            # For torch.export (strict and non-strict), we don't do functional decomposition.
-            # Instead, we preserve the custom triton ops as custom ops. This is because we want
-            # the exported program to be high-level and serializable. If we decompose
-            # the custom op to a functional hop and make it a node in exported program,
-            # we need to figure out ways of serializing the hop and its arguments, which can be triton.jited
-            # functions and triton dtypes. This is undesireble because:
-            # - it can be tedious to maintain a layer that serializes the jited function (e.g. with a string) and dtypes.
-            # - exported program will contain the implementation detail (e.g. triton source code) for a specific
-            #   backend (GPU), which is probably at a wrong level of abstraction.
-            # - changes to triton or the serialization logic for triton arguments can be BC breaking
-            #
-            # In the short term, we expect users to have a seperate aot_compile stage that compiles the exported program
-            # into a Cubin file on the same machine that users call export, which does autotuning and removes triton
-            # dependency and serve the model with Cubin. This guarantees that triton changes won't break BC.
-            # In the long term, we may export multiple cubins for the triton op directly.
-
-            from torch.compiler import is_exporting
-
-            if is_exporting():
-                return mode.__torch_dispatch__(op, types, args, kwargs)
-            else:
-                with mode:
-                    return fn(*args, **kwargs)
+            with mode:
+                return fn(*args, **kwargs)
 
         result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
         return result