[inductor] Reland: Add flag to ignore unsupported @triton.autotune args in user-written kernel compilation (#132562)

Summary:
This is an attempt to reland [#131431](https://github.com/pytorch/pytorch/pull/131431), which, in its original form, caused issues internally.

We currently don't support some of the `triton.autotune` arguments when compiling user-written Triton kernels with PT2. This PR adds a flag that bypasses this restriction, in order to unblock internal compilation in some cases. The flag is documented with a note explaining why setting it is not a good idea.

Test Plan:
```
python test/inductor/test_triton_kernels.py -k test_triton_kernel_autotune_with_unsupported_args
...
----------------------------------------------------------------------
Ran 3 tests in 3.636s

OK
```

Differential Revision: D60701839

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132562
Approved by: https://github.com/chenyang78
This commit is contained in:
Adnan Akhundov 2024-08-03 06:31:27 +00:00 committed by PyTorch MergeBot
parent 06581c277a
commit 8ad9f89ccc
4 changed files with 84 additions and 18 deletions

View file

@ -673,6 +673,29 @@ def forward(self, x_1, output_1):
output2 = torch.zeros_like(t1, requires_grad=grad)
self.assertEqual(compiled_func(t1, t2, output2), torch_add)
@requires_gpu
@skipIfRocm  # https://github.com/pytorch/pytorch/actions/runs/10051552819/job/27782048305?pr=131431
@common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
@patch.object(
    torch._inductor.config, "unsafe_ignore_unsupported_triton_autotune_args", True
)
def test_triton_kernel_autotune_with_unsupported_args(self, backend):
    """Compile a kernel autotuned with unsupported args while the unsafe flag is on.

    With ``unsafe_ignore_unsupported_triton_autotune_args`` patched to True,
    ``call_triton`` is compiled with ``fullgraph=True`` under the given
    ``backend`` and its result is compared against the uncompiled call.
    """

    def call_triton(x: torch.Tensor, y: torch.Tensor):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        # The launch grid has one program per element.
        add_kernel_autotuned_with_unsupported_args[(n_elements,)](
            x, y, output, n_elements
        )
        return output

    t1 = torch.rand(256, device=GPU_TYPE)
    t2 = torch.rand(256, device=GPU_TYPE)

    # Reference result from running the function without torch.compile.
    torch_add = call_triton(t1, t2)

    compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
    compiled_add = compiled_func(t1, t2)

    self.assertEqual(compiled_add, torch_add)
@requires_gpu
@common_utils.parametrize("grad", [False, True])
@common_utils.parametrize("dynamic", [False, True])

View file

@ -814,26 +814,33 @@ class TritonHOPifier:
# Newer versions of Triton renamed the attributes warmup -> num_warmups and rep -> num_reps.
# get_first_attr is used to maintain backward compatibility with both names.
if (
(
"warmup" in defaults
and defaults["warmup"].default
!= torch._dynamo.utils.get_first_attr(
kernel, "num_warmups", "warmup"
not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args
and (
(
"warmup" in defaults
and defaults["warmup"].default
!= torch._dynamo.utils.get_first_attr(
kernel, "num_warmups", "warmup"
)
)
or (
"rep" in defaults
and defaults["rep"].default
!= torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep")
)
or (
"prune_configs_by" in defaults
and defaults["prune_configs_by"].default
!= kernel.early_config_prune
)
# Set via reset_to_zero argument
or len(kernel.reset_idx) != 0
or len(kernel.restore_idx) != 0
or (
"use_cuda_graph" in defaults
and defaults["use_cuda_graph"].default != kernel.use_cuda_graph
)
)
or (
"rep" in defaults
and defaults["rep"].default
!= torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep")
)
or (
"prune_configs_by" in defaults
and defaults["prune_configs_by"].default
!= kernel.early_config_prune
)
# Set via reset_to_zero argument
or len(kernel.reset_idx) != 0
or len(kernel.restore_idx) != 0
):
self.raise_unsupported(
"Only configs and keys are supported for triton.autotune"

View file

@ -647,6 +647,12 @@ decompose_mem_bound_mm: bool = False
# In the common case, most inputs will be aligned.
assume_aligned_inputs: bool = False
# For user-written Triton kernels compiled with the model, ignore the unsupported
# arguments passed to @triton.autotune in the user's code. This is unsafe: ignoring
# the unsupported args may lead to unexpected autotuning behavior. Don't set this
# unless you know what you're doing.
unsafe_ignore_unsupported_triton_autotune_args: bool = False
# config specific to codegen/cpp.py
class cpp:

View file

@ -118,6 +118,36 @@ if has_triton():
tmp2 = tmp0 + tmp1
tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)
def _dummy_early_config_prune(configs, *_, **__):
return configs
# Element-wise add kernel whose @triton.autotune call passes warmup, rep and
# prune_configs_by in addition to configs and key.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
    ],
    key=[],
    warmup=10,
    rep=20,
    prune_configs_by={"early_config_prune": _dummy_early_config_prune},
)
@triton.jit
def add_kernel_autotuned_with_unsupported_args(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    # Each program instance covers the BLOCK_SIZE offsets starting at
    # pid * BLOCK_SIZE.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask off offsets at or beyond n_elements so loads and stores stay in range.
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    output = x + y
    tl.store(out_ptr + offsets, output, mask=mask)
@triton.jit
def add_kernel_with_scaling(
in_ptr0,