From 39fc280dce82b4f3af9710afd5fc05df4d147dda Mon Sep 17 00:00:00 2001
From: eellison
Date: Thu, 18 Apr 2024 21:38:18 -0700
Subject: [PATCH] Don't precompile already-seen keys, limit epilogue choices (#122642)

Two changes:
- In epilogue benchmark fusion, only benchmark the top triton template
  choices, capped by the new max_epilogue_benchmarked_choices config. In
  HuggingFace (HF) benchmarks, choices beyond these were essentially
  never selected.
- Share a single precompilation function among matmuls with the same key.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122642
Approved by: https://github.com/shunting314
ghstack dependencies: #124030
---
 test/inductor/test_max_autotune.py  | 16 ++++++++++++++++
 torch/_inductor/codecache.py        |  2 +-
 torch/_inductor/config.py           |  3 +++
 torch/_inductor/scheduler.py        |  6 ++++++
 torch/_inductor/select_algorithm.py | 26 ++++++++++++++++++++++++--
 5 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py
index beb1b22df83..bbcff4f87fc 100644
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -447,6 +447,22 @@ class TestMaxAutotune(TestCase):
         fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
         self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
 
+    @config.patch(autotune_local_cache=False, autotune_remote_cache=False)
+    def test_precompilations(self):
+        def fn(a, b, c):
+            a = (a @ b) @ c
+            a, b, c = (t.to(torch.float16) for t in [a, b, c])
+            return (a @ b) @ c
+
+        fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
+        inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)]
+
+        self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)
+
+        from torch._dynamo.utils import counters
+
+        self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 2)
+
     def test_cat_addmm(self):
         def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
             return torch.cat(
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 465891fc7f3..a2b88cf4744 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -301,7 +301,7 @@ class PersistentCache(CacheBase):
             return hit
 
         if config.max_autotune or config.max_autotune_gemm:
-            local_cache = self.get_local_cache()
+            local_cache = self.get_local_cache() if config.autotune_local_cache else {}
             # check local cache first since it is data specific to the current machine
             if (
                 not check_cache(local_cache)
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index f96e95ab5d5..6b7c9beec9d 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -306,6 +306,9 @@
 benchmark_multi_templates = (
     os.environ.get("TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES", "0") == "1"
 )
 
+# how many of the top triton kernel choices to benchmark for epilogue fusion
+max_epilogue_benchmarked_choices = 3
+
 # how many nodes to allow into a single fusion
 max_fusion_size = 64
diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py
index 049a77a4efe..e3f7f395147 100644
--- a/torch/_inductor/scheduler.py
+++ b/torch/_inductor/scheduler.py
@@ -1835,6 +1835,8 @@ class Scheduler:
         min_ms_fused = float("inf")
         ms_fused_choice = None
 
+        triton_choices = 0
+
         for choice, unfused_time in choice_timings.items():
             if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase):
                 continue
@@ -1842,6 +1844,10 @@ class Scheduler:
             if unfused_time >= ms1 + ms2:
                 continue
 
+            triton_choices += 1
+            if triton_choices > config.max_epilogue_benchmarked_choices:
+                break
+
             # TODO - parallel compile triton templates
             # TODO - should prune/skip choices that are not within certain % of best choice
             with node1.node.swap_as_triton_caller(choice):
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index 3261909d2be..4272d5034d0 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -866,6 +866,15 @@ class ErrorFromChoice(RuntimeError):
 
 
 class AlgorithmSelectorCache(PersistentCache):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # autotuning can occur later, in the scheduler, so there is no
+        # guarantee that the first lowering for a given key will also be
+        # the first to benchmark it. Share a single precompilation function
+        # for all lowerings of a particular key.
+        self.precompile_cache: Dict[str, Callable[[], None]] = {}
+
     def __call__(
         self,
         name,
@@ -902,6 +911,8 @@
         def make_benchmark_fn():
             return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)
 
+        inputs_key = repr([self.key_of(x) for x in input_nodes])
+
         def precompile(choices) -> Callable[[], None]:
             def no_op(*args, **kwargs):
                 return
@@ -927,13 +938,19 @@
             timings = self.lookup(
                 choices,
                 name,
-                repr([self.key_of(x) for x in input_nodes]),
+                inputs_key,
                 benchmark=None,
             )
 
             if timings:
                 return no_op
 
+            precompile_key = (
+                f"{name}: {inputs_key} : {torch.get_float32_matmul_precision()}"
+            )
+            if precompile_func := self.precompile_cache.get(precompile_key):
+                return precompile_func
+
             log.info(
                 "Multithreaded precompilation for %d choices using %d worker threads",
                 len(choices),
@@ -947,7 +964,9 @@
                 timeout=precompilation_timeout_seconds,
             )
 
+            @functools.lru_cache(None)
             def wait_on_futures():
+                counters["inductor"]["select_algorithm_precompile"] += 1
                 try:
                     iterator = iter(futures)
                     while True:
@@ -963,8 +982,11 @@
                     )
                 except StopIteration:
                     pass
+
                 executor.shutdown(wait=True)
 
+            self.precompile_cache[precompile_key] = wait_on_futures
+
             return wait_on_futures
 
         def autotune(choices):
@@ -985,7 +1007,7 @@
             timings = self.lookup(
                 choices,
                 name,
-                repr([self.key_of(x) for x in input_nodes]),
+                inputs_key,
                 autotune,
             )
             autotune_elapse = time.time() - autotune_start_ts
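
Reviewer sketch (not part of the commit): the precompilation-sharing pattern
above, reduced to its core. The names below (PrecompileSharer, the 4-worker
pool, the toy choice callables) are hypothetical stand-ins for inductor's
internals; the real precompile_key also folds in the kernel name and
torch.get_float32_matmul_precision().

    import functools
    from concurrent.futures import ThreadPoolExecutor
    from typing import Callable, Dict, List

    class PrecompileSharer:
        def __init__(self) -> None:
            # one precompilation waiter per autotune key
            self.precompile_cache: Dict[str, Callable[[], None]] = {}

        def precompile(
            self, key: str, choices: List[Callable[[], None]]
        ) -> Callable[[], None]:
            # a lowering that has already seen this key reuses the same
            # waiter instead of kicking off duplicate compile jobs
            if (fn := self.precompile_cache.get(key)) is not None:
                return fn

            executor = ThreadPoolExecutor(max_workers=4)
            futures = [executor.submit(c) for c in choices]

            # lru_cache(None) makes the wait idempotent: every lowering of
            # this key holds the same function object, but the futures are
            # only drained (and the pool shut down) once
            @functools.lru_cache(None)
            def wait_on_futures() -> None:
                for f in futures:
                    f.result()
                executor.shutdown(wait=True)

            self.precompile_cache[key] = wait_on_futures
            return wait_on_futures

Calling precompile() twice with the same key returns the identical function,
which is why the new test expects the select_algorithm_precompile counter to
land at 2 (two distinct dtype keys) rather than 4 (four matmuls).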
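A similar sketch of the scheduler-side cap: during epilogue benchmark fusion,
only the first max_epilogue_benchmarked_choices viable triton choices are
benchmarked fused. Here choice_timings is assumed to iterate best-first, and
benchmark_fused is a hypothetical hook standing in for the swap-and-benchmark
step in Scheduler.

    MAX_EPILOGUE_BENCHMARKED_CHOICES = 3

    def pick_fused_choice(choice_timings, ms1, ms2, benchmark_fused):
        # choice_timings: {choice: unfused_ms}, assumed ordered best-first;
        # ms1 + ms2 is the unfused baseline of the two nodes being fused
        min_ms_fused = float("inf")
        ms_fused_choice = None
        triton_choices = 0
        for choice, unfused_time in choice_timings.items():
            # skip choices whose unfused time already exceeds the baseline
            if unfused_time >= ms1 + ms2:
                continue
            triton_choices += 1
            if triton_choices > MAX_EPILOGUE_BENCHMARKED_CHOICES:
                break  # per the PR, later choices were essentially never picked
            ms_fused = benchmark_fused(choice)
            if ms_fused < min_ms_fused:
                min_ms_fused, ms_fused_choice = ms_fused, choice
        # fuse only if the best fused time beats the unfused baseline
        return ms_fused_choice if min_ms_fused < ms1 + ms2 else None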