[inductor] Simplify get_launch_args_* handling (#143835)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143835
Approved by: https://github.com/eellison, https://github.com/shunting314
ghstack dependencies: #143813, #143814, #143815, #143817, #143818
commit 6bdf2addc5
parent 138efb3002
Author: Jason Ansel
Date:   2024-12-25 07:21:58 -08:00
Committed by: PyTorch MergeBot
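At a high level, the diff below deletes the get_launch_args_with_kernel_launch_metadata / get_launch_args_without_kernel_launch_metadata helpers, which reassembled the runner's argument tuple on every kernel launch, and instead bakes the argument list into the exec'd launcher source as a list of expression strings (runner_args). When no launch_enter_hook is installed, the literal string "None" is spliced into that list, so bin.launch_metadata is never evaluated on the launch hot path. The following is a minimal, self-contained sketch of the string-splicing idea; every name in it (the call_args values, the stand-in scope entries) is an illustrative placeholder, not the real code from triton_heuristics.py.

# Minimal sketch of baking launch args into an exec'd launcher.
# Every name here is an illustrative stand-in, not a PyTorch internal.
call_args = ["in_ptr0", "out_ptr0", "xnumel"]  # hypothetical kernel arguments

# The argument list is chosen once, as strings; "None" is burned in for the
# launch_metadata slot because no launch_enter_hook is installed in this example.
runner_args = [
    "grid_0", "grid_1", "grid_2", "stream", "function", "metadata",
    "None",
    "launch_enter_hook", "launch_exit_hook",
    *call_args,
]

scope = {
    "runner": print,             # stand-in for the compiled kernel's C wrapper
    "function": "cu_function",   # stand-in handles looked up by the generated code
    "metadata": None,
    "launch_enter_hook": None,
    "launch_exit_hook": None,
}

exec(
    f"""
def launcher({', '.join(call_args)}, grid, stream):
    grid_0, grid_1, grid_2 = grid
    runner({', '.join(runner_args)})
""".lstrip(),
    scope,
)

scope["launcher"](1, 2, 3, grid=(8, 1, 1), stream=0)  # runner (print) receives the final args

Because the argument expressions are fixed strings, the generated launcher contains a single direct call with no per-launch branching or tuple packing.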

@@ -526,139 +526,68 @@ class CachingAutotuner(KernelInterface):
else binary.metadata
),
"shared": binary_shared,
"num_warps": (
binary.num_warps
if hasattr(binary, "num_warps")
else binary.metadata.num_warps
),
"cta_args": (
(
binary.num_ctas,
*get_first_attr(binary, "cluster_dims", "clusterDims"),
)
if hasattr(binary, "num_ctas")
else (
(binary.metadata.num_ctas, *binary.metadata.cluster_dims)
if hasattr(binary, "metadata")
else ()
)
),
"function": get_first_attr(binary, "function", "cu_function"),
"runner": get_first_attr(binary, "run", "c_wrapper"),
}
scope["num_warps"] = (
binary.num_warps
if hasattr(binary, "num_warps")
else binary.metadata.num_warps
)
scope["cta_args"] = (
(binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims"))
if hasattr(binary, "num_ctas")
else (
(binary.metadata.num_ctas, *binary.metadata.cluster_dims)
if hasattr(binary, "metadata")
else ()
)
)
scope["function"] = get_first_attr(binary, "function", "cu_function")
def get_launch_args_without_kernel_launch_metadata(
grid,
grid_0,
grid_1,
grid_2,
stream,
function,
metadata,
bin,
launch_enter_hook,
launch_exit_hook,
num_warps,
shared,
cta_args,
args,
):
"""
Construct launch args before CompiledKernel.launch_metadata is added.
"""
return (
grid_0,
grid_1,
grid_2,
num_warps,
*cta_args,
shared,
stream,
function,
launch_enter_hook,
launch_exit_hook,
metadata,
)
# Getting the kernel launch args is extremely perf-sensitive. Evaluating
# `bin.launch_metadata` is relatively expensive, and returns None unless a
# `launch_enter_hook` is installed. So if we don't have that hook installed,
# we want to burn None in to the launch args with zero overhead.
# See https://github.com/pytorch/pytorch/issues/123597
if binary.launch_enter_hook:
def get_launch_args_with_kernel_launch_metadata(
grid,
grid_0,
grid_1,
grid_2,
stream,
function,
metadata,
bin,
launch_enter_hook,
launch_exit_hook,
num_warps,
shared,
cta_args,
args,
):
"""
Construct launch args after CompiledKernel.launch_metadata is added
by https://github.com/openai/triton/pull/3492 .
"""
return (
grid_0,
grid_1,
grid_2,
stream,
function,
metadata,
bin.launch_metadata(grid, stream, *args),
launch_enter_hook,
launch_exit_hook,
if not hasattr(binary, "launch_metadata"):
# launch args before CompiledKernel.launch_metadata is added.
# TODO(jansel): delete this branch in mid-2025
runner_args = [
"grid_0",
"grid_1",
"grid_2",
"num_warps",
"*cta_args",
"shared",
"stream",
"function",
"launch_enter_hook",
"launch_exit_hook",
"metadata",
*call_args,
]
else: # args after CompiledKernel.launch_metadata: https://github.com/openai/triton/pull/3492
# Getting the kernel launch args is extremely perf-sensitive. Evaluating
# `bin.launch_metadata` is relatively expensive, and returns None unless a
# `launch_enter_hook` is installed. So if we don't have that hook installed,
# we want to burn None in to the launch args with zero overhead.
# See https://github.com/pytorch/pytorch/issues/123597
if binary.launch_enter_hook:
launch_metadata = (
f"bin.launch_metadata(grid, stream, {', '.join(call_args)})"
)
else:
def get_launch_args_with_kernel_launch_metadata(
grid,
grid_0,
grid_1,
grid_2,
stream,
function,
metadata,
bin,
launch_enter_hook,
launch_exit_hook,
num_warps,
shared,
cta_args,
args,
):
"""
Construct launch args after CompiledKernel.launch_metadata is added
by https://github.com/openai/triton/pull/3492 .
"""
return (
grid_0,
grid_1,
grid_2,
stream,
function,
metadata,
None,
launch_enter_hook,
launch_exit_hook,
)
scope["get_launch_args"] = (
get_launch_args_with_kernel_launch_metadata
if hasattr(binary, "launch_metadata")
else get_launch_args_without_kernel_launch_metadata
)
scope["runner"] = get_first_attr(binary, "run", "c_wrapper")
else:
launch_metadata = "None"
runner_args = [
"grid_0",
"grid_1",
"grid_2",
"stream",
"function",
"metadata",
launch_metadata,
"launch_enter_hook",
"launch_exit_hook",
*call_args,
]
exec(
f"""
@@ -667,14 +596,7 @@ class CachingAutotuner(KernelInterface):
grid_0, grid_1, grid_2 = grid(grid_meta)
else:
grid_0, grid_1, grid_2 = grid
args = {', '.join(call_args)},
launch_args = get_launch_args(
grid, grid_0, grid_1, grid_2, stream, function,
metadata, bin, launch_enter_hook, launch_exit_hook,
num_warps, shared, cta_args, args
)
runner(*launch_args, *args)
runner({', '.join(runner_args)})
return bin
""".lstrip(),
scope,
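
For reference, here is roughly what the interpolated launcher source looks like once the f-string above is expanded, assuming a Triton that provides CompiledKernel.launch_metadata and a hypothetical kernel whose call_args are in_ptr0, out_ptr0, xnumel. The def line sits above the hunk shown here, so it is reconstructed from context; treat the whole snippet as an approximation of the generated code, not a verbatim copy. The free names (runner, function, metadata, grid_meta, bin, and the hooks) are supplied through the scope dict passed to exec.

# Approximate generated launcher when no launch_enter_hook is installed:
# the literal None is burned into the call, so bin.launch_metadata is never
# evaluated at launch time.
def launcher(in_ptr0, out_ptr0, xnumel, grid, stream):
    if callable(grid):
        grid_0, grid_1, grid_2 = grid(grid_meta)
    else:
        grid_0, grid_1, grid_2 = grid
    runner(grid_0, grid_1, grid_2, stream, function, metadata,
           None,
           launch_enter_hook, launch_exit_hook,
           in_ptr0, out_ptr0, xnumel)
    return bin

# With a launch_enter_hook installed, the None slot would instead read
# bin.launch_metadata(grid, stream, in_ptr0, out_ptr0, xnumel),
# re-evaluated on every launch.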