Don't precompile already-seen keys, limit epilogue choices (#122642)
Two changes:
- In epilogue benchmark fusion, only take the top 6 choices; in the HF (HuggingFace) benchmark suite, essentially no choice ranked below that cutoff was ever selected.
- Share a single precompilation function among matmuls with the same key.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122642
Approved by: https://github.com/shunting314
ghstack dependencies: #124030
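The second change amounts to memoizing a per-key precompilation callable. A minimal standalone sketch of that pattern, not Inductor's actual code (get_precompile_fn and compile_all are hypothetical names):

import functools
from typing import Callable, Dict

_precompile_cache: Dict[str, Callable[[], None]] = {}

def get_precompile_fn(key: str, compile_all: Callable[[], None]) -> Callable[[], None]:
    # Lowerings that share a key reuse one precompilation function.
    if (fn := _precompile_cache.get(key)) is not None:
        return fn

    @functools.lru_cache(None)  # zero-arg function: body runs at most once
    def wait_on_compilation() -> None:
        compile_all()

    _precompile_cache[key] = wait_on_compilation
    return wait_on_compilation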
commit 39fc280dce (parent 7ae835eee4)
6 changed files with 50 additions and 3 deletions
test/hi.py (new, empty file)
test/inductor/test_max_autotune.py
@@ -447,6 +447,22 @@ class TestMaxAutotune(TestCase):
         fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
+        self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0)
+
+    @config.patch(autotune_local_cache=False, autotune_remote_cache=False)
+    def test_precompilations(self):
+        def fn(a, b, c):
+            a = (a @ b) @ c
+            a, b, c = (t.to(torch.float16) for t in [a, b, c])
+            return (a @ b) @ c
+
+        fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn)
+        inputs = [torch.rand([256, 256], device="cuda") for _ in range(3)]
+
+        self.assertEqual(fn(*inputs), fn_c(*inputs), atol=1e-2, rtol=1e-2)
+
+        from torch._dynamo.utils import counters
+
+        self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 2)

     def test_cat_addmm(self):
         def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
             return torch.cat(
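The expected count of 2 follows from key collapsing: fn issues four 256x256 matmuls, but they share only two distinct keys (one float32, one float16), so precompilation runs once per key. A toy illustration, with key_of as a hypothetical stand-in for the real input hashing:

def key_of(shape, dtype):  # hypothetical stand-in for input hashing
    return f"mm: {shape} {dtype}"

keys = [
    key_of((256, 256), "float32"),  # a @ b
    key_of((256, 256), "float32"),  # (a @ b) @ c
    key_of((256, 256), "float16"),  # a @ b after the cast
    key_of((256, 256), "float16"),  # (a @ b) @ c after the cast
]
assert len(set(keys)) == 2  # matches select_algorithm_precompile == 2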
torch/_inductor/codecache.py
@@ -301,7 +301,7 @@ class PersistentCache(CacheBase):
             return hit

         if config.max_autotune or config.max_autotune_gemm:
-            local_cache = self.get_local_cache()
+            local_cache = self.get_local_cache() if config.autotune_local_cache else {}
             # check local cache first since it is data specific to the current machine
             if (
                 not check_cache(local_cache)
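With autotune_local_cache now a config flag, tests can force fresh autotuning so the precompile counter is meaningful. A sketch of how a caller might disable both caches, assuming torch._inductor.config exposes both flags as the test above patches:

import torch
from torch._inductor import config

# With both caches off, lookup() cannot short-circuit to a cached timing,
# so every distinct key actually triggers precompilation once when the
# compiled function is first run on CUDA inputs.
with config.patch(autotune_local_cache=False, autotune_remote_cache=False):
    compiled = torch.compile(lambda a, b: a @ b, mode="max-autotune-no-cudagraphs")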
torch/_inductor/config.py
@@ -306,6 +306,9 @@ benchmark_multi_templates = (
     os.environ.get("TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES", "0") == "1"
 )

+# how many of the top triton kernels to benchmark for epilogue fusion
+max_epilogue_benchmarked_choices = 3
+
 # how many nodes to allow into a single fusion
 max_fusion_size = 64
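The new knob caps how many fused candidates the scheduler benchmarks; raising it trades longer compile times for a wider search. A usage sketch, not part of the diff:

from torch._inductor import config

# Default is 3 as of this commit; raise it to benchmark more epilogue candidates.
config.max_epilogue_benchmarked_choices = 8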
torch/_inductor/scheduler.py
@@ -1835,6 +1835,8 @@ class Scheduler:
         min_ms_fused = float("inf")
         ms_fused_choice = None

+        triton_choices = 0
+
         for choice, unfused_time in choice_timings.items():
             if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase):
                 continue
@@ -1842,6 +1844,10 @@
             if unfused_time >= ms1 + ms2:
                 continue

+            triton_choices += 1
+            if triton_choices > config.max_epilogue_benchmarked_choices:
+                break
+
             # TODO - parallel compile triton templates
             # TODO - should prune/skip choices that are not within certain % of best choice
             with node1.node.swap_as_triton_caller(choice):
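The loop skips Triton template candidates whose unfused time already loses to the baseline, then stops after benchmarking the configured cap. A standalone sketch of that shape, under the assumption that candidates iterate best-first (all names here are illustrative):

def pick_fused_choice(candidates, benchmark_fused, baseline_ms, cap=3):
    # candidates: iterable of (choice, unfused_ms), assumed sorted best-first
    best_choice, best_ms = None, float("inf")
    tried = 0
    for choice, unfused_ms in candidates:
        if unfused_ms >= baseline_ms:  # fusion cannot win here; skip
            continue
        tried += 1
        if tried > cap:  # limit expensive fused benchmarking
            break
        fused_ms = benchmark_fused(choice)
        if fused_ms < best_ms:
            best_choice, best_ms = choice, fused_ms
    return best_choice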
torch/_inductor/select_algorithm.py
@@ -866,6 +866,15 @@ class ErrorFromChoice(RuntimeError):


 class AlgorithmSelectorCache(PersistentCache):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Autotuning happens in the scheduler, so there is no guarantee that
+        # the first lowering for a given key will also be the first to
+        # benchmark it. Share a single precompilation function across all
+        # lowerings of a particular key.
+        self.precompile_cache: Dict[str, Callable[[], None]] = {}
+
     def __call__(
         self,
         name,
@@ -902,6 +911,8 @@ class AlgorithmSelectorCache(PersistentCache):
         def make_benchmark_fn():
             return self.make_benchmark_fn(choices, input_nodes, layout, input_gen_fns)

+        inputs_key = repr([self.key_of(x) for x in input_nodes])
+
         def precompile(choices) -> Callable[[], None]:
             def no_op(*args, **kwargs):
                 return
@@ -927,13 +938,19 @@
             timings = self.lookup(
                 choices,
                 name,
-                repr([self.key_of(x) for x in input_nodes]),
+                inputs_key,
                 benchmark=None,
             )

             if timings:
                 return no_op

+            precompile_key = (
+                f"{name}: {inputs_key} : {torch.get_float32_matmul_precision()}"
+            )
+            if precompile_func := self.precompile_cache.get(precompile_key):
+                return precompile_func
+
             log.info(
                 "Multithreaded precompilation for %d choices using %d worker threads",
                 len(choices),
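The cache key combines the op name, the input signature, and the global float32 matmul precision, since changing torch.set_float32_matmul_precision alters which kernels are eligible and how they perform. An illustrative reconstruction of the key's three components (the inputs_key string here is a made-up example, not real key_of output):

import torch

name = "mm"
inputs_key = "[('cuda', 'torch.float32', (256, 256))]"  # made-up example
precompile_key = f"{name}: {inputs_key} : {torch.get_float32_matmul_precision()}"
print(precompile_key)  # e.g. "mm: [...] : highest"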
@@ -947,7 +964,9 @@
                 timeout=precompilation_timeout_seconds,
             )

+            @functools.lru_cache(None)
             def wait_on_futures():
+                counters["inductor"]["select_algorithm_precompile"] += 1
                 try:
                     iterator = iter(futures)
                     while True:
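Decorating the zero-argument wait_on_futures with functools.lru_cache(None) makes its body execute at most once, no matter how many lowerings later invoke the shared function; that is also why the counter increments exactly once per key. A minimal demonstration of the run-once trick:

import functools

calls = 0

@functools.lru_cache(None)  # zero-arg function: body runs at most once
def wait_once():
    global calls
    calls += 1

wait_once(); wait_once(); wait_once()
assert calls == 1  # later callers return immediately with the cached result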
@@ -963,8 +982,11 @@
                         )
                 except StopIteration:
                     pass

                 executor.shutdown(wait=True)

+            self.precompile_cache[precompile_key] = wait_on_futures
+
             return wait_on_futures

         def autotune(choices):
@@ -985,7 +1007,7 @@
         timings = self.lookup(
             choices,
             name,
-            repr([self.key_of(x) for x in input_nodes]),
+            inputs_key,
             autotune,
         )
         autotune_elapse = time.time() - autotune_start_ts