[inductor] Add kernel_code logging artifact (#126631)

This is useful for some compile errors where we don't finish outputting the full graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126631
Approved by: https://github.com/shunting314
This commit is contained in:
Jason Ansel 2024-05-20 16:36:25 -07:00 committed by PyTorch MergeBot
parent 4e921593a4
commit c08afbb3da
4 changed files with 19 additions and 4 deletions

View file

@@ -108,6 +108,7 @@ else:
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code")
LOCK_TIMEOUT = 600
@@ -3109,6 +3110,7 @@ class AsyncCompile:
return cls.pool().submit(task)
def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
kernel_code_log.info("Triton Kernel:\n%s", source_code)
_compile_start()
_set_triton_ptxas_path()
@@ -3132,6 +3134,7 @@ class AsyncCompile:
return MultiKernelCall(*args, **kwargs)
def cpp(self, source_code: str):
kernel_code_log.info("CPP Kernel:\n%s", source_code)
if config.compile_threads <= 1:
return CppCodeCache.load(source_code).kernel
else:
@@ -3139,6 +3142,7 @@ class AsyncCompile:
return LambdaFuture(lambda: get_result().kernel)
def cpp_pybinding(self, argtypes: List[str], source_code: str):
kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
if config.compile_threads <= 1:
return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
else:
@@ -3148,6 +3152,8 @@ class AsyncCompile:
return LambdaFuture(get_result)
def cuda(self, source_code, dst_file_ext):
kernel_code_log.info("CUDA Kernel:\n%s", source_code)
def task():
return CUDACodeCache.load(source_code, dst_file_ext)[0]

View file

@@ -428,14 +428,12 @@ class Loops(IRNode):
@cache_on_self
def inner_fn_opcount(self):
from .ir import FlexibleLayout
opcounter = OpCounterCSE(V.MockHandler())
with V.set_ops_handler(opcounter), patch.object(
FlexibleLayout, "allow_indexing", True
):
result = self.inner_fn(*self.inner_fn_args())
self.inner_fn(*self.inner_fn_args())
return opcounter.op_count
def inner_fn_args(self):

View file

@@ -214,6 +214,7 @@ def set_logs(
trace_call: bool = False,
trace_bytecode: bool = False,
output_code: bool = False,
kernel_code: bool = False,
schedule: bool = False,
perf_hints: bool = False,
post_grad_graphs: bool = False,
@@ -355,7 +356,10 @@ def set_logs(
traces bytecode. Default: ``False``
output_code (:class:`bool`):
Whether to emit the TorchInductor output code. Default: ``False``
Whether to emit the TorchInductor output code on a per-graph basis. Default: ``False``
kernel_code (:class:`bool`):
Whether to emit the TorchInductor output code on a per-kernel basis. Default: ``False``
schedule (:class:`bool`):
Whether to emit the TorchInductor schedule. Default: ``False``
@@ -473,6 +477,7 @@ def set_logs(
trace_call=trace_call,
trace_bytecode=trace_bytecode,
output_code=output_code,
kernel_code=kernel_code,
schedule=schedule,
perf_hints=perf_hints,
post_grad_graphs=post_grad_graphs,

View file

@@ -122,6 +122,12 @@ register_artifact(
off_by_default=True,
visible=True,
)
register_artifact(
"kernel_code",
"Prints the code that Inductor generates (on a per-kernel basis)",
off_by_default=True,
visible=True,
)
register_artifact(
"schedule",
"Inductor scheduler information. Useful if working on Inductor fusion algo",