Don't precompile already-seen keys, limit epilogue choices (#122642)

Two changes:
- in epilogue benchmark fusion, only take top 6 choices. There were basically no choices taken after this in HF.
- Share a single precompilation function among matmuls with same key.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122642
Approved by: https://github.com/shunting314
ghstack dependencies: #124030
This commit is contained in:
eellison 2024-04-18 21:38:18 -07:00 committed by PyTorch MergeBot
parent 7ae835eee4
commit 39fc280dce
6 changed files with 50 additions and 3 deletions

0
test/hi.py Normal file
View file

View file

@ -447,6 +447,22 @@ class TestMaxAutotune(TestCase):
fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
@config.patch(autotune_local_cache=False, autotune_remote_cache=False)
def test_precompilations(self):
    """Matmuls that share an autotuning key share one precompilation.

    The compiled graph performs the same (a @ b) @ c pattern twice — once
    in fp32 and once in fp16 — so there are two distinct keys and the
    precompile counter must read exactly 2, not 4.  Caches are disabled via
    the patch above so the counter reflects this run only.
    """

    def fn(a, b, c):
        a = (a @ b) @ c
        # Re-run the same matmul chain in float16 to create a second,
        # distinct autotuning key (dtype is part of the key).
        a, b, c = (t.to(torch.float16) for t in [a, b, c])
        return (a @ b) @ c

    fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)

    inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)]
    # Compiled result must match eager within fp16-friendly tolerances.
    self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)

    from torch._dynamo.utils import counters

    # Two keys (fp32 and fp16 256x256 matmul) -> two precompilations.
    self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 2)
def test_cat_addmm(self):
def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
return torch.cat(

View file

@ -301,7 +301,7 @@ class PersistentCache(CacheBase):
return hit
if config.max_autotune or config.max_autotune_gemm:
local_cache = self.get_local_cache()
local_cache = self.get_local_cache() if config.autotune_local_cache else {}
# check local cache first since it is data specific to the current machine
if (
not check_cache(local_cache)

View file

@ -306,6 +306,9 @@ benchmark_multi_templates = (
os.environ.get("TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES", "0") == "1"
)
# Take how many of the top triton kernels to benchmark epilogue
max_epilogue_benchmarked_choices = 3
# how many nodes to allow into a single fusion
max_fusion_size = 64

View file

@ -1835,6 +1835,8 @@ class Scheduler:
min_ms_fused = float("inf")
ms_fused_choice = None
triton_choices = 0
for choice, unfused_time in choice_timings.items():
if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase):
continue
@ -1842,6 +1844,10 @@ class Scheduler:
if unfused_time >= ms1 + ms2:
continue
triton_choices += 1
if triton_choices > config.max_epilogue_benchmarked_choices:
break
# TODO - parallel compile triton templates
# TODO - should prune/skip choices that are not within certain % of best choice
with node1.node.swap_as_triton_caller(choice):

View file

@ -866,6 +866,15 @@ class ErrorFromChoice(RuntimeError):
class AlgorithmSelectorCache(PersistentCache):
def __init__(self, *args, **kwargs):
    """Persistent cache that also deduplicates precompilation per key."""
    super().__init__(*args, **kwargs)
    # The autotuning will occur in the scheduler, so there is no
    # guarantee that the first lowering for a given key will also be the
    # first to benchmark it.  Share a single precompilation function for
    # all lowerings of a particular key.
    self.precompile_cache: Dict[str, Callable[[], None]] = {}
def __call__(
self,
name,
@ -902,6 +911,8 @@ class AlgorithmSelectorCache(PersistentCache):
def make_benchmark_fn():
return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)
inputs_key = repr([self.key_of(x) for x in input_nodes])
def precompile(choices) -> Callable[[], None]:
def no_op(*args, **kwargs):
return
@ -927,13 +938,19 @@ class AlgorithmSelectorCache(PersistentCache):
timings = self.lookup(
choices,
name,
repr([self.key_of(x) for x in input_nodes]),
inputs_key,
benchmark=None,
)
if timings:
return no_op
precompile_key = (
f"{name}: {inputs_key} : {torch.get_float32_matmul_precision()}"
)
if precompile_func := self.precompile_cache.get(precompile_key):
return precompile_func
log.info(
"Multithreaded precompilation for %d choices using %d worker threads",
len(choices),
@ -947,7 +964,9 @@ class AlgorithmSelectorCache(PersistentCache):
timeout=precompilation_timeout_seconds,
)
@functools.lru_cache(None)
def wait_on_futures():
counters["inductor"]["select_algorithm_precompile"] += 1
try:
iterator = iter(futures)
while True:
@ -963,8 +982,11 @@ class AlgorithmSelectorCache(PersistentCache):
)
except StopIteration:
pass
executor.shutdown(wait=True)
self.precompile_cache[precompile_key] = wait_on_futures
return wait_on_futures
def autotune(choices):
@ -985,7 +1007,7 @@ class AlgorithmSelectorCache(PersistentCache):
timings = self.lookup(
choices,
name,
repr([self.key_of(x) for x in input_nodes]),
inputs_key,
autotune,
)
autotune_elapse = time.time() - autotune_start_ts