Mirror of https://github.com/saymrwulf/pytorch.git
Synced 2026-05-14 20:57:59 +00:00
[inductor] Reland: Add flag to ignore unsupported @triton.autotune args in user-written kernel compilation (#132562)
Summary: This is a reland attempt of [#131431](https://github.com/pytorch/pytorch/pull/131431), as the PR, in its original form, caused issues internally. We currently don't support some of the `triton.autotune` arguments when compiling user-written Triton kernels with PT2. This PR adds a flag to circumvent that check, to unblock internal compilation in some cases. The flag is supplied with docs explaining why setting it is not a good idea.

Test Plan:
```
python test/inductor/test_triton_kernels.py -k test_triton_kernel_autotune_with_unsupported_args
...
----------------------------------------------------------------------
Ran 3 tests in 3.636s

OK
```

Differential Revision: D60701839

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132562
Approved by: https://github.com/chenyang78
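For illustration, a minimal sketch of the workflow this flag unblocks. The kernel and function names here are made up for the example (only the config flag itself comes from this PR); it mirrors what the new test below does: a user-written kernel whose `@triton.autotune` call passes `warmup` and `rep`, which PT2 does not support, compiled with the flag flipped so those arguments are ignored rather than rejected.

```python
import torch
import triton
import triton.language as tl

# Hypothetical user-written kernel (not from the PR): `warmup` and `rep` are
# among the @triton.autotune arguments that PT2 does not support.
@triton.autotune(
    configs=[triton.Config({"BLOCK_SIZE": 64}, num_warps=4)],
    key=[],
    warmup=10,  # unsupported by PT2
    rep=20,     # unsupported by PT2
)
@triton.jit
def my_add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    my_add_kernel[grid](x, y, out, n)
    return out

# Opt in to ignoring the unsupported autotune args; without this, PT2 rejects
# the kernel (see the TritonHOPifier hunk below). Unsafe: read the config docs.
torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args = True
compiled = torch.compile(add, fullgraph=True)
x = torch.rand(256, device="cuda")
y = torch.rand(256, device="cuda")
torch.testing.assert_close(compiled(x, y), x + y)
```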
This commit is contained in:

parent 06581c277a
commit 8ad9f89ccc

4 changed files with 84 additions and 18 deletions
```diff
@@ -673,6 +673,29 @@ def forward(self, x_1, output_1):
         output2 = torch.zeros_like(t1, requires_grad=grad)
         self.assertEqual(compiled_func(t1, t2, output2), torch_add)

+    @requires_gpu
+    @skipIfRocm  # https://github.com/pytorch/pytorch/actions/runs/10051552819/job/27782048305?pr=131431
+    @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"])
+    @patch.object(
+        torch._inductor.config, "unsafe_ignore_unsupported_triton_autotune_args", True
+    )
+    def test_triton_kernel_autotune_with_unsupported_args(self, backend):
+        def call_triton(x: torch.Tensor, y: torch.Tensor):
+            output = torch.zeros_like(x)
+            n_elements = output.numel()
+            add_kernel_autotuned_with_unsupported_args[(n_elements,)](
+                x, y, output, n_elements
+            )
+            return output
+
+        t1 = torch.rand(256, device=GPU_TYPE)
+        t2 = torch.rand(256, device=GPU_TYPE)
+
+        torch_add = call_triton(t1, t2)
+        compiled_func = torch.compile(call_triton, backend=backend, fullgraph=True)
+        compiled_add = compiled_func(t1, t2)
+        self.assertEqual(compiled_add, torch_add)
+
     @requires_gpu
     @common_utils.parametrize("grad", [False, True])
     @common_utils.parametrize("dynamic", [False, True])
```
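For context, a hedged sketch (not part of the PR) of the failure this test's `@patch.object` avoids, reusing `call_triton`, `t1`, and `t2` from the test above; with the flag at its default, the TritonHOPifier check in the next hunk rejects the kernel:

```python
import torch

# With the flag left at its default (False), compiling the same call is
# expected to be rejected by the check shown in the next hunk.
torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args = False
compiled = torch.compile(call_triton, backend="inductor", fullgraph=True)
try:
    compiled(t1, t2)
except Exception as exc:  # exact exception type depends on the Dynamo frontend
    assert "Only configs and keys are supported for triton.autotune" in str(exc)
```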
```diff
@@ -814,26 +814,33 @@ class TritonHOPifier:
         # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep.
         # The call to get_first_attr is to maintain backward-compatibility.
         if (
-            (
-                "warmup" in defaults
-                and defaults["warmup"].default
-                != torch._dynamo.utils.get_first_attr(
-                    kernel, "num_warmups", "warmup"
+            not torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args
+            and (
+                (
+                    "warmup" in defaults
+                    and defaults["warmup"].default
+                    != torch._dynamo.utils.get_first_attr(
+                        kernel, "num_warmups", "warmup"
+                    )
+                )
+                or (
+                    "rep" in defaults
+                    and defaults["rep"].default
+                    != torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep")
+                )
+                or (
+                    "prune_configs_by" in defaults
+                    and defaults["prune_configs_by"].default
+                    != kernel.early_config_prune
+                )
+                # Set via reset_to_zero argument
+                or len(kernel.reset_idx) != 0
+                or len(kernel.restore_idx) != 0
+                or (
+                    "use_cuda_graph" in defaults
+                    and defaults["use_cuda_graph"].default != kernel.use_cuda_graph
                 )
             )
-            or (
-                "rep" in defaults
-                and defaults["rep"].default
-                != torch._dynamo.utils.get_first_attr(kernel, "num_reps", "rep")
-            )
-            or (
-                "prune_configs_by" in defaults
-                and defaults["prune_configs_by"].default
-                != kernel.early_config_prune
-            )
-            # Set via reset_to_zero argument
-            or len(kernel.reset_idx) != 0
-            or len(kernel.restore_idx) != 0
         ):
             self.raise_unsupported(
                 "Only configs and keys are supported for triton.autotune"
```
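The `defaults` mapping consulted above holds the parameters of the Triton `Autotuner` constructor, so each clause asks whether the user passed a non-default value; the new leading `not torch._inductor.config...` conjunct short-circuits the whole check when the flag is set. A standalone sketch of that detection pattern, assuming `Autotuner` is exposed as `triton.runtime.Autotuner` and with illustrative names (the PR's own `defaults` is built elsewhere in the file):

```python
import inspect

import triton

# Map each parameter of Autotuner.__init__ to its inspect.Parameter, so that
# defaults["warmup"].default is the default value of the `warmup` argument.
defaults = inspect.signature(triton.runtime.Autotuner.__init__).parameters

def warmup_overridden(kernel) -> bool:
    # Mimics torch._dynamo.utils.get_first_attr(kernel, "num_warmups", "warmup"):
    # newer Triton renames the attribute from `warmup` to `num_warmups`.
    value = getattr(kernel, "num_warmups", getattr(kernel, "warmup", None))
    return "warmup" in defaults and value != defaults["warmup"].default
```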
```diff
@@ -647,6 +647,12 @@ decompose_mem_bound_mm: bool = False
 # In the common case, most inputs will be aligned.
 assume_aligned_inputs: bool = False

+# For the user-written Triton kernels compiled with the model, ignore the unsupported
+# arguments passed to the @triton.autotune in the user's code; this is unsafe, as
+# ignoring the unsupported args may lead to unexpected autotuning behavior: don't
+# set unless you know what you're doing.
+unsafe_ignore_unsupported_triton_autotune_args: bool = False
+

 # config specific to codegen/cpp.py
 class cpp:
```
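As a usage note, the new option can be toggled like other inductor config knobs: by direct assignment, via `unittest.mock.patch.object` as the new test does, or via the config module's `patch` helper. A brief hedged sketch (`fn` is a stand-in, not from the PR):

```python
import torch

def fn(x):
    return x + 1  # stand-in for a function that calls a user-written Triton kernel

# Global toggle; read the warning in the config comment above before doing this.
torch._inductor.config.unsafe_ignore_unsupported_triton_autotune_args = True

# Scoped toggle: torch._inductor.config also exposes a patch() helper usable
# as a context manager (or decorator) that restores the old value on exit.
with torch._inductor.config.patch(unsafe_ignore_unsupported_triton_autotune_args=True):
    compiled = torch.compile(fn, fullgraph=True)
    compiled(torch.ones(4))
```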
```diff
@@ -118,6 +118,36 @@ if has_triton():
         tmp2 = tmp0 + tmp1
         tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)

+    def _dummy_early_config_prune(configs, *_, **__):
+        return configs
+
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+        warmup=10,
+        rep=20,
+        prune_configs_by={"early_config_prune": _dummy_early_config_prune},
+    )
+    @triton.jit
+    def add_kernel_autotuned_with_unsupported_args(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x + y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
     @triton.jit
     def add_kernel_with_scaling(
         in_ptr0,
```
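This fixture bundles three of the unsupported `@triton.autotune` arguments into one kernel: `warmup`, `rep`, and `prune_configs_by` (with a prune function that keeps every config, so autotuning behavior is unchanged). Launched eagerly it behaves like a plain masked add; a hedged sketch mirroring the launch in the new test, assuming a CUDA device (the test uses its `GPU_TYPE` helper instead):

```python
import torch

x = torch.rand(256, device="cuda")
y = torch.rand(256, device="cuda")
out = torch.zeros_like(x)
n_elements = out.numel()

# The (n_elements,) grid launches one program per element; that over-provisions
# programs but stays correct because of the `mask` inside the kernel.
add_kernel_autotuned_with_unsupported_args[(n_elements,)](x, y, out, n_elements)
torch.testing.assert_close(out, x + y)
```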