[inductor] Simplify get_launch_args_* handling (#143835)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143835
Approved by: https://github.com/eellison, https://github.com/shunting314
ghstack dependencies: #143813, #143814, #143815, #143817, #143818
parent 138efb3002
commit 6bdf2addc5
1 changed file with 60 additions and 138 deletions
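In rough terms, the commit drops the get_launch_args_with/without_kernel_launch_metadata helpers, which rebuilt the runner's argument tuple on every kernel launch, and instead joins a list of argument expressions (runner_args) straight into the source string that exec compiles into the launcher. The sketch below is a minimal, self-contained illustration of that pattern, not the Inductor code itself; the names runner, make_launcher_indirect, and make_launcher_baked are hypothetical.

# A minimal, self-contained sketch of the pattern (hypothetical names; the real
# launcher is generated inside CachingAutotuner and takes many more arguments).

def runner(grid_0, grid_1, grid_2, stream, extra):
    # Stand-in for the compiled Triton kernel's run function.
    print("launch", grid_0, grid_1, grid_2, stream, extra)


# Before: a per-launch helper reassembles the argument tuple on every call.
def make_launcher_indirect():
    def get_launch_args(grid_0, grid_1, grid_2, stream, extra):
        return (grid_0, grid_1, grid_2, stream, extra)

    def launcher(grid, stream, extra):
        grid_0, grid_1, grid_2 = grid
        launch_args = get_launch_args(grid_0, grid_1, grid_2, stream, extra)
        runner(*launch_args)

    return launcher


# After: the argument list is joined into the exec'd source once, so the
# generated launcher calls the runner directly with a fixed argument list.
def make_launcher_baked():
    runner_args = ["grid_0", "grid_1", "grid_2", "stream", "extra"]
    scope = {"runner": runner}
    exec(
        f"""
def launcher(grid, stream, extra):
    grid_0, grid_1, grid_2 = grid
    runner({', '.join(runner_args)})
""".lstrip(),
        scope,
    )
    return scope["launcher"]


launch = make_launcher_baked()
launch((8, 1, 1), stream=0, extra=42)  # -> launch 8 1 1 0 42

Baking the argument list into the generated source moves the string-joining cost to launcher-construction time, so each launch is a single direct call into the runner.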
@@ -526,139 +526,68 @@ class CachingAutotuner(KernelInterface):
                 else binary.metadata
             ),
             "shared": binary_shared,
+            "num_warps": (
+                binary.num_warps
+                if hasattr(binary, "num_warps")
+                else binary.metadata.num_warps
+            ),
+            "cta_args": (
+                (
+                    binary.num_ctas,
+                    *get_first_attr(binary, "cluster_dims", "clusterDims"),
+                )
+                if hasattr(binary, "num_ctas")
+                else (
+                    (binary.metadata.num_ctas, *binary.metadata.cluster_dims)
+                    if hasattr(binary, "metadata")
+                    else ()
+                )
+            ),
+            "function": get_first_attr(binary, "function", "cu_function"),
+            "runner": get_first_attr(binary, "run", "c_wrapper"),
         }

-        scope["num_warps"] = (
-            binary.num_warps
-            if hasattr(binary, "num_warps")
-            else binary.metadata.num_warps
-        )
-
-        scope["cta_args"] = (
-            (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims"))
-            if hasattr(binary, "num_ctas")
-            else (
-                (binary.metadata.num_ctas, *binary.metadata.cluster_dims)
-                if hasattr(binary, "metadata")
-                else ()
-            )
-        )
-
-        scope["function"] = get_first_attr(binary, "function", "cu_function")
-
-        def get_launch_args_without_kernel_launch_metadata(
-            grid,
-            grid_0,
-            grid_1,
-            grid_2,
-            stream,
-            function,
-            metadata,
-            bin,
-            launch_enter_hook,
-            launch_exit_hook,
-            num_warps,
-            shared,
-            cta_args,
-            args,
-        ):
-            """
-            Construct launch args before CompiledKernel.launch_metadata is added.
-            """
-            return (
-                grid_0,
-                grid_1,
-                grid_2,
-                num_warps,
-                *cta_args,
-                shared,
-                stream,
-                function,
-                launch_enter_hook,
-                launch_exit_hook,
-                metadata,
-            )
-
-        # Getting the kernel launch args is extremely perf-sensitive. Evaluating
-        # `bin.launch_metadata` is relatively expensive, and returns None unless a
-        # `launch_enter_hook` is installed. So if we don't have that hook installed,
-        # we want to burn None in to the launch args with zero overhead.
-        # See https://github.com/pytorch/pytorch/issues/123597
-        if binary.launch_enter_hook:
-
-            def get_launch_args_with_kernel_launch_metadata(
-                grid,
-                grid_0,
-                grid_1,
-                grid_2,
-                stream,
-                function,
-                metadata,
-                bin,
-                launch_enter_hook,
-                launch_exit_hook,
-                num_warps,
-                shared,
-                cta_args,
-                args,
-            ):
-                """
-                Construct launch args after CompiledKernel.launch_metadata is added
-                by https://github.com/openai/triton/pull/3492 .
-                """
-                return (
-                    grid_0,
-                    grid_1,
-                    grid_2,
-                    stream,
-                    function,
-                    metadata,
-                    bin.launch_metadata(grid, stream, *args),
-                    launch_enter_hook,
-                    launch_exit_hook,
-                )
+        if not hasattr(binary, "launch_metadata"):
+            # launch args before CompiledKernel.launch_metadata is added.
+            # TODO(jansel): delete this branch in mid-2025
+            runner_args = [
+                "grid_0",
+                "grid_1",
+                "grid_2",
+                "num_warps",
+                "*cta_args",
+                "shared",
+                "stream",
+                "function",
+                "launch_enter_hook",
+                "launch_exit_hook",
+                "metadata",
+                *call_args,
+            ]
+        else:  # args after CompiledKernel.launch_metadata: https://github.com/openai/triton/pull/3492
+            # Getting the kernel launch args is extremely perf-sensitive. Evaluating
+            # `bin.launch_metadata` is relatively expensive, and returns None unless a
+            # `launch_enter_hook` is installed. So if we don't have that hook installed,
+            # we want to burn None in to the launch args with zero overhead.
+            # See https://github.com/pytorch/pytorch/issues/123597
+            if binary.launch_enter_hook:
+                launch_metadata = (
+                    f"bin.launch_metadata(grid, stream, {', '.join(call_args)})"
+                )
-
-        else:
-
-            def get_launch_args_with_kernel_launch_metadata(
-                grid,
-                grid_0,
-                grid_1,
-                grid_2,
-                stream,
-                function,
-                metadata,
-                bin,
-                launch_enter_hook,
-                launch_exit_hook,
-                num_warps,
-                shared,
-                cta_args,
-                args,
-            ):
-                """
-                Construct launch args after CompiledKernel.launch_metadata is added
-                by https://github.com/openai/triton/pull/3492 .
-                """
-                return (
-                    grid_0,
-                    grid_1,
-                    grid_2,
-                    stream,
-                    function,
-                    metadata,
-                    None,
-                    launch_enter_hook,
-                    launch_exit_hook,
-                )
-
-        scope["get_launch_args"] = (
-            get_launch_args_with_kernel_launch_metadata
-            if hasattr(binary, "launch_metadata")
-            else get_launch_args_without_kernel_launch_metadata
-        )
-
-        scope["runner"] = get_first_attr(binary, "run", "c_wrapper")
+            else:
+                launch_metadata = "None"
+            runner_args = [
+                "grid_0",
+                "grid_1",
+                "grid_2",
+                "stream",
+                "function",
+                "metadata",
+                launch_metadata,
+                "launch_enter_hook",
+                "launch_exit_hook",
+                *call_args,
+            ]

         exec(
             f"""
@@ -667,14 +596,7 @@ class CachingAutotuner(KernelInterface):
                     grid_0, grid_1, grid_2 = grid(grid_meta)
                 else:
                     grid_0, grid_1, grid_2 = grid
-
-                args = {', '.join(call_args)},
-                launch_args = get_launch_args(
-                    grid, grid_0, grid_1, grid_2, stream, function,
-                    metadata, bin, launch_enter_hook, launch_exit_hook,
-                    num_warps, shared, cta_args, args
-                )
-                runner(*launch_args, *args)
+                runner({', '.join(runner_args)})
                 return bin
             """.lstrip(),
             scope,
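For intuition about the new else branch above, here is a toy rendering of how runner_args joins into the generated call, and how the launch-metadata slot is burned in as the literal None when no launch_enter_hook is installed, so nothing extra is evaluated per launch. The call_args values are made up for illustration; this is not the actual Inductor launcher source.

# Toy rendering (made-up call_args; not the actual Inductor launcher source) of
# how runner_args joins into the generated call, and how the launch-metadata
# slot is burned in as the literal None when no launch_enter_hook is installed.

call_args = ["arg0", "arg1"]
have_enter_hook = False  # pretend no launch_enter_hook is installed

if have_enter_hook:
    # Only worth paying for when a hook will actually consume the metadata.
    launch_metadata = f"bin.launch_metadata(grid, stream, {', '.join(call_args)})"
else:
    # Burned into the generated source as a literal: zero per-launch overhead.
    launch_metadata = "None"

runner_args = [
    "grid_0",
    "grid_1",
    "grid_2",
    "stream",
    "function",
    "metadata",
    launch_metadata,
    "launch_enter_hook",
    "launch_exit_hook",
    *call_args,
]

print(f"runner({', '.join(runner_args)})")
# runner(grid_0, grid_1, grid_2, stream, function, metadata, None, launch_enter_hook, launch_exit_hook, arg0, arg1)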