From 6bdf2addc5e83cdf67af1261c417b3c91104ad02 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Wed, 25 Dec 2024 07:21:58 -0800 Subject: [PATCH] [inductor] Simplify get_launch_args_* handling (#143835) Pull Request resolved: https://github.com/pytorch/pytorch/pull/143835 Approved by: https://github.com/eellison, https://github.com/shunting314 ghstack dependencies: #143813, #143814, #143815, #143817, #143818 --- torch/_inductor/runtime/triton_heuristics.py | 198 ++++++------------- 1 file changed, 60 insertions(+), 138 deletions(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index eca54fe913d..14d26e6c5fd 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -526,139 +526,68 @@ class CachingAutotuner(KernelInterface): else binary.metadata ), "shared": binary_shared, + "num_warps": ( + binary.num_warps + if hasattr(binary, "num_warps") + else binary.metadata.num_warps + ), + "cta_args": ( + ( + binary.num_ctas, + *get_first_attr(binary, "cluster_dims", "clusterDims"), + ) + if hasattr(binary, "num_ctas") + else ( + (binary.metadata.num_ctas, *binary.metadata.cluster_dims) + if hasattr(binary, "metadata") + else () + ) + ), + "function": get_first_attr(binary, "function", "cu_function"), + "runner": get_first_attr(binary, "run", "c_wrapper"), } - scope["num_warps"] = ( - binary.num_warps - if hasattr(binary, "num_warps") - else binary.metadata.num_warps - ) - - scope["cta_args"] = ( - (binary.num_ctas, *get_first_attr(binary, "cluster_dims", "clusterDims")) - if hasattr(binary, "num_ctas") - else ( - (binary.metadata.num_ctas, *binary.metadata.cluster_dims) - if hasattr(binary, "metadata") - else () - ) - ) - - scope["function"] = get_first_attr(binary, "function", "cu_function") - - def get_launch_args_without_kernel_launch_metadata( - grid, - grid_0, - grid_1, - grid_2, - stream, - function, - metadata, - bin, - launch_enter_hook, - launch_exit_hook, - num_warps, - shared, - cta_args, - args, - ): - """ - Construct launch args before CompiledKernel.launch_metadata is added. - """ - return ( - grid_0, - grid_1, - grid_2, - num_warps, - *cta_args, - shared, - stream, - function, - launch_enter_hook, - launch_exit_hook, - metadata, - ) - - # Getting the kernel launch args is extremely perf-sensitive. Evaluating - # `bin.launch_metadata` is relatively expensive, and returns None unless a - # `launch_enter_hook` is installed. So if we don't have that hook installed, - # we want to burn None in to the launch args with zero overhead. - # See https://github.com/pytorch/pytorch/issues/123597 - if binary.launch_enter_hook: - - def get_launch_args_with_kernel_launch_metadata( - grid, - grid_0, - grid_1, - grid_2, - stream, - function, - metadata, - bin, - launch_enter_hook, - launch_exit_hook, - num_warps, - shared, - cta_args, - args, - ): - """ - Construct launch args after CompiledKernel.launch_metadata is added - by https://github.com/openai/triton/pull/3492 . - """ - return ( - grid_0, - grid_1, - grid_2, - stream, - function, - metadata, - bin.launch_metadata(grid, stream, *args), - launch_enter_hook, - launch_exit_hook, + if not hasattr(binary, "launch_metadata"): + # launch args before CompiledKernel.launch_metadata is added. 
+ # TODO(jansel): delete this branch in mid-2025 + runner_args = [ + "grid_0", + "grid_1", + "grid_2", + "num_warps", + "*cta_args", + "shared", + "stream", + "function", + "launch_enter_hook", + "launch_exit_hook", + "metadata", + *call_args, + ] + else: # args after CompiledKernel.launch_metadata: https://github.com/openai/triton/pull/3492 + # Getting the kernel launch args is extremely perf-sensitive. Evaluating + # `bin.launch_metadata` is relatively expensive, and returns None unless a + # `launch_enter_hook` is installed. So if we don't have that hook installed, + # we want to burn None in to the launch args with zero overhead. + # See https://github.com/pytorch/pytorch/issues/123597 + if binary.launch_enter_hook: + launch_metadata = ( + f"bin.launch_metadata(grid, stream, {', '.join(call_args)})" ) - - else: - - def get_launch_args_with_kernel_launch_metadata( - grid, - grid_0, - grid_1, - grid_2, - stream, - function, - metadata, - bin, - launch_enter_hook, - launch_exit_hook, - num_warps, - shared, - cta_args, - args, - ): - """ - Construct launch args after CompiledKernel.launch_metadata is added - by https://github.com/openai/triton/pull/3492 . - """ - return ( - grid_0, - grid_1, - grid_2, - stream, - function, - metadata, - None, - launch_enter_hook, - launch_exit_hook, - ) - - scope["get_launch_args"] = ( - get_launch_args_with_kernel_launch_metadata - if hasattr(binary, "launch_metadata") - else get_launch_args_without_kernel_launch_metadata - ) - - scope["runner"] = get_first_attr(binary, "run", "c_wrapper") + else: + launch_metadata = "None" + runner_args = [ + "grid_0", + "grid_1", + "grid_2", + "stream", + "function", + "metadata", + launch_metadata, + "launch_enter_hook", + "launch_exit_hook", + *call_args, + ] exec( f""" @@ -667,14 +596,7 @@ class CachingAutotuner(KernelInterface): grid_0, grid_1, grid_2 = grid(grid_meta) else: grid_0, grid_1, grid_2 = grid - - args = {', '.join(call_args)}, - launch_args = get_launch_args( - grid, grid_0, grid_1, grid_2, stream, function, - metadata, bin, launch_enter_hook, launch_exit_hook, - num_warps, shared, cta_args, args - ) - runner(*launch_args, *args) + runner({', '.join(runner_args)}) return bin """.lstrip(), scope,
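
For readers unfamiliar with the pattern above: the patch replaces per-launch dispatch through the `get_launch_args_*` helper functions with compile-time code generation. The launch-argument list (`runner_args`) is built once, as a list of source-code strings, and burned into the body of an `exec`-generated launcher. The standalone sketch below illustrates that pattern; it is not the PyTorch implementation, and the names `make_launcher`, `fake_runner`, and `compute_metadata` are hypothetical stand-ins.

# Standalone sketch of the exec-based launcher codegen used in this patch.
# All names here (make_launcher, fake_runner, compute_metadata) are
# illustrative stand-ins, not the actual PyTorch internals.

def fake_runner(*args):
    # Stand-in for binary.run / c_wrapper: just record what it was called with.
    print("runner called with:", args)

def compute_metadata():
    # Stand-in for bin.launch_metadata(grid, stream, *args).
    return {"name": "kernel"}

def make_launcher(call_args, has_launch_metadata, has_enter_hook):
    # Build the launch-argument list as source-code strings, once, at
    # compile time, so the generated launcher has no per-call branching.
    if not has_launch_metadata:
        # Older Triton: no CompiledKernel.launch_metadata.
        runner_args = ["grid_0", "grid_1", "grid_2", "shared", *call_args]
    else:
        # Newer Triton: burn in either a metadata call or a literal None,
        # mirroring the zero-overhead trick for the no-hook case.
        metadata = "compute_metadata()" if has_enter_hook else "None"
        runner_args = ["grid_0", "grid_1", "grid_2", metadata, *call_args]

    scope = {"runner": fake_runner, "compute_metadata": compute_metadata}
    exec(
        f"""
def launcher({', '.join(call_args)}, grid, shared):
    grid_0, grid_1, grid_2 = grid
    runner({', '.join(runner_args)})
""".lstrip(),
        scope,
    )
    return scope["launcher"]

launcher = make_launcher(["x", "y"], has_launch_metadata=True, has_enter_hook=False)
launcher(1, 2, grid=(8, 1, 1), shared=0)
# -> runner called with: (8, 1, 1, None, 1, 2)

Because `', '.join(runner_args)` is evaluated once when the launcher is generated, the launcher body contains no branching, no tuple construction, and no helper-function call at launch time. In particular, when no `launch_enter_hook` is installed, the literal `None` is baked directly into the call, which is the zero-overhead behavior referenced in https://github.com/pytorch/pytorch/issues/123597.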