diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index bf67e5ff7be..bbc5710c8f3 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -40,13 +40,11 @@ from torch._dynamo.testing import (
     skipIfPy312,
 )
 from torch._dynamo.utils import ifdynstaticdefault
-from torch._guards import CompileContext, CompileId
 from torch._inductor.aoti_eager import (
     aoti_compile_with_persistent_cache,
     aoti_eager_cache_dir,
     load_aoti_eager_cache,
 )
-from torch._inductor.codecache import cpp_prefix_path
 from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext
 from torch._inductor.fx_passes import pad_mm
 from torch._inductor.test_case import TestCase as InductorTestCase
@@ -54,7 +52,6 @@ from torch._inductor.utils import (
     add_scheduler_init_hook,
     run_and_get_code,
     run_and_get_cpp_code,
-    run_and_get_kernels,
     run_and_get_triton_code,
     run_fw_bw_and_get_code,
 )
@@ -6218,99 +6215,6 @@ class CommonTemplate:
             (torch.arange(-1e-5, 1e-5, 1e-7).to(dtype=dtype),),
         )
 
-    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
-    @config.patch(force_disable_caches=True)
-    @skip_if_cpp_wrapper("run_and_get_kernels issue")
-    def test_deterministic_codegen(self):
-        if "cpu" in str(self.device) and config.is_fbcode():
-            raise unittest.SkipTest("cpp packaging is wacky in fbcode")
-
-        @torch.compile(fullgraph=True)
-        def a(x):
-            return x.cos().sin().softmax(-1)
-
-        @torch.compile(fullgraph=True)
-        def b(x):
-            return x.sin().cos().softmax(-1)
-
-        @torch.compile(fullgraph=True)
-        def c(x):
-            return x.cos().sin().softmax(-1)
-
-        x = torch.randn(16, 256, device=self.device)
-        _, (coda_a0,) = run_and_get_kernels(a, x)
-        _, (coda_b0,) = run_and_get_kernels(b, x)
-        _, (coda_c0,) = run_and_get_kernels(c, x)
-        self.assertEqual(coda_a0, coda_c0)
-
-        # compile in a different order
-        torch.compiler.reset()
-        _, (coda_c1,) = run_and_get_kernels(c, x)
-        _, (coda_a1,) = run_and_get_kernels(a, x)
-        _, (coda_b1,) = run_and_get_kernels(b, x)
-        self.assertEqual(coda_a0, coda_a1)
-        self.assertEqual(coda_b0, coda_b1)
-        self.assertEqual(coda_c0, coda_c1)
-
-        # force a different CompileId
-        torch.compiler.reset()
-        CompileContext_init = CompileContext.__init__
-        with patch.object(
-            CompileContext,
-            "__init__",
-            lambda self, _: CompileContext_init(self, CompileId(999, 999)),
-        ):
-            _, (coda_a2,) = run_and_get_kernels(a, x)
-            _, (coda_c2,) = run_and_get_kernels(c, x)
-            _, (coda_b2,) = run_and_get_kernels(b, x)
-        self.assertEqual(coda_a0, coda_a2)
-        self.assertEqual(coda_b0, coda_b2)
-        self.assertEqual(coda_c0, coda_c2)
-
-    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
-    @config.patch(force_disable_caches=True)
-    @skip_if_cpp_wrapper("run_and_get_kernels issue")
-    def test_deterministic_codegen_on_graph_break(self):
-        if "cpu" in str(self.device) and config.is_fbcode():
-            raise unittest.SkipTest("cpp packaging is wacky in fbcode")
-
-        def a(x):
-            return x.cos().sin().softmax(-1)
-
-        @torch.compile()
-        def b(x):
-            x = a(x)
-            torch._dynamo.graph_break()
-            x = a(x)
-            return x
-
-        x = torch.randn(16, 256, device=self.device)
-        _, (code0, code1) = run_and_get_kernels(b, x)
-        self.assertEqual(code0, code1)
-
-    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
-    @config.patch(force_disable_caches=True)
-    @skip_if_cpp_wrapper("run_and_get_kernels issue")
-    def test_deterministic_codegen_with_suffix(self):
-        if "cpu" in str(self.device) and config.is_fbcode():
-            raise unittest.SkipTest("cpp packaging is wacky in fbcode")
-
-        @torch.compile(fullgraph=True)
-        def a(x):
-            return x.cos().sin().softmax(-1)
-
-        @torch.compile(fullgraph=True)
-        def b(x, y):
-            x = x.cos().sin().softmax(-1)
-            x = torch.matmul(x, y)
-            return x
-
-        x = torch.randn(16, 256, device=self.device)
-        y = torch.randn(256, 256, device=self.device)
-        _, (code0,) = run_and_get_kernels(a, x)
-        _, (code1,) = run_and_get_kernels(b, x, y)
-        self.assertEqual(code0, code1)
-
     def test_flip(self):
         def fn(x):
             return torch.flip(x, (-1,)), torch.flip(x, (0, 2))
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index c94fd34d9eb..73492b7edf9 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -3069,6 +3069,7 @@ class TritonKernel(SIMDKernel):
 
     @staticmethod
     def inductor_meta_common():
+        compile_id = torch._guards.CompileContext.current_compile_id()
         inductor_meta = {
             "backend_hash": torch.utils._triton.triton_hash_with_backend(),
             "are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(),
@@ -3083,6 +3084,8 @@ class TritonKernel(SIMDKernel):
             "min_split_scan_rblock": config.triton.min_split_scan_rblock,
             "spill_threshold": config.triton.spill_threshold,
             "store_cubin": config.triton.store_cubin,
+            "compile_id": str(compile_id) if compile_id else None,
+            "is_forward": not V.graph.is_backward,
         }
         if torch.version.hip is not None:
             inductor_meta["is_hip"] = True
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 3a9afc4220c..14d26e6c5fd 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -18,6 +18,7 @@ import time
 from typing import Any, Container, Dict, List, Optional, Tuple
 
 import torch
+from torch._guards import CompileId
 from torch.utils._ordered_set import OrderedSet
 
 from ..triton_bundler import TritonBundler
@@ -762,6 +763,8 @@ class CachingAutotuner(KernelInterface):
             log_pt2_compile_event=True,
             metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
             dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
+            compile_id=CompileId.from_string(self.inductor_meta.get("compile_id")),
+            is_forward=self.inductor_meta.get("is_forward"),
         ):
             timings = {
                 launcher: self.bench(launcher, *args, **kwargs)
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index a979db5ba14..4cbe2026d30 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -1477,14 +1477,6 @@ def run_and_get_code(fn, *args, **kwargs) -> Tuple[Any, List[str]]:
     return result, source_codes
 
 
-def run_and_get_kernels(fn, *args, **kwargs) -> Tuple[Any, List[str]]:
-    result, source_codes = run_and_get_code(fn, *args, **kwargs)
-    kernels = []
-    for code in source_codes:
-        kernels.extend(re.findall(r"'''.*?'''", code, re.DOTALL))
-    return result, kernels
-
-
 def run_fw_bw_and_get_code(fn):
     def run_with_backward():
         result = fn()