Don't uselessly recompute the axiom dict on every static eval call (#138967)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138967
Approved by: https://github.com/ezyang
This commit is contained in:
Laith Sakka 2024-10-31 10:35:19 -07:00 committed by PyTorch MergeBot
parent c4d9428b17
commit 6a1c451479
3 changed files with 39 additions and 24 deletions

View file

@ -1,66 +1,65 @@
add_loop_eager, compile_time_instruction_count, 3004749893, 0.015
add_loop_eager,compile_time_instruction_count,3027000000,0.015
add_loop_eager_dynamic, compile_time_instruction_count, 5563298740, 0.025
add_loop_eager_dynamic,compile_time_instruction_count,5596000000,0.025
add_loop_inductor, compile_time_instruction_count, 24064639114, 0.015
add_loop_inductor,compile_time_instruction_count,24260000000,0.015
add_loop_inductor_dynamic_gpu, compile_time_instruction_count, 40992578178, 0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40380000000,0.025
add_loop_inductor_gpu, compile_time_instruction_count, 22822864522, 0.015
add_loop_inductor_gpu,compile_time_instruction_count,23010000000,0.015
basic_modules_ListOfLinears_eager, compile_time_instruction_count, 1034818091, 0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1028000000,0.015
basic_modules_ListOfLinears_inductor, compile_time_instruction_count, 19049541914, 0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19170000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad, compile_time_instruction_count, 15806042948, 0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15810000000,0.015
basic_modules_ListOfLinears_inductor_gpu, compile_time_instruction_count, 16403080126, 0.20
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,16760000000,0.2
update_hint_regression, compile_time_instruction_count, 1795333141, 0.02
update_hint_regression,compile_time_instruction_count,1743000000,0.02
sum_floordiv_regression, compile_time_instruction_count, 1154135694, 0.015
sum_floordiv_regression,compile_time_instruction_count,1160000000,0.015
symint_sum, compile_time_instruction_count, 3270576815, 0.015
symint_sum,compile_time_instruction_count,3293000000,0.015
aotdispatcher_inference_nosubclass_cpu, compile_time_instruction_count, 1981730523, 0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2001000000,0.015
aotdispatcher_inference_subclass_cpu, compile_time_instruction_count, 5711895807, 0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5778000000,0.015
aotdispatcher_partitioner_cpu, compile_time_instruction_count, 8963708885 , 0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8989000000,0.015
aotdispatcher_training_nosubclass_cpu, compile_time_instruction_count, 3795666651, 0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3822000000,0.015
aotdispatcher_training_subclass_cpu, compile_time_instruction_count, 10175364418, 0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10260000000,0.015

1 add_loop_eager compile_time_instruction_count 3004749893 3027000000 0.015
2 add_loop_eager_dynamic compile_time_instruction_count 5563298740 5596000000 0.025
3 add_loop_inductor compile_time_instruction_count 24064639114 24260000000 0.015
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 40992578178 40380000000 0.025
5 add_loop_inductor_gpu compile_time_instruction_count 22822864522 23010000000 0.015
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 1034818091 1028000000 0.015
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 19049541914 19170000000 0.015
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 15806042948 15810000000 0.015
9 basic_modules_ListOfLinears_inductor_gpu compile_time_instruction_count 16403080126 16760000000 0.20 0.2
10 update_hint_regression compile_time_instruction_count 1795333141 1743000000 0.02
11 sum_floordiv_regression compile_time_instruction_count 1154135694 1160000000 0.015
12 symint_sum compile_time_instruction_count 3270576815 3293000000 0.015
13 aotdispatcher_inference_nosubclass_cpu compile_time_instruction_count 1981730523 2001000000 0.015
14 aotdispatcher_inference_subclass_cpu compile_time_instruction_count 5711895807 5778000000 0.015
15 aotdispatcher_partitioner_cpu compile_time_instruction_count 8963708885 8989000000 0.015
16 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3795666651 3822000000 0.015
17 aotdispatcher_training_subclass_cpu compile_time_instruction_count 10175364418 10260000000 0.015
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

View file

@ -10252,6 +10252,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {0 < Mod(s0, 3): False, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, False: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False, True: True}
> Right: {}
==> divisible: values don't match.
> Left: {Mod(s0, 3)}
> Right: {}
@ -10289,6 +10292,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {False: False, True: True}
> Right: {}
==> guards: values don't match.
> Left: [Eq(s0, 3)]
> Right: []
@ -10330,6 +10336,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {3 <= s0: True, s0 < 3: False}
> Right: {}
==> guards: values don't match.
> Left: [s0 >= 3]
> Right: []
@ -10362,6 +10371,9 @@ ShapeEnv not equal: field values don't match:
"""\
ShapeEnv not equal: field values don't match:
==> axioms: values don't match.
> Left: {0 < PythonMod(u0, 3): False, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, False: False, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) <= 0: True, True: True}
> Right: {}
==> deferred_runtime_asserts: values don't match.
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
> Right: {}

View file

@ -3021,6 +3021,7 @@ class ShapeEnv:
)
self.guards: List[ShapeGuard] = []
self.axioms: Dict[sympy.Expr, sympy.Expr] = {}
# Maps symbolic ints to their original concrete values
# Currently populated from tensors
self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
@ -5347,15 +5348,16 @@ class ShapeEnv:
expr = canonicalize_bool_expr(expr)
# Pattern matching
symbols = tuple(expr.free_symbols)
if axioms is None:
axioms = self.get_axioms(symbols, compute_hint=compute_hint)
subst = {}
for e in axioms:
if e.free_symbols.issubset(expr.free_symbols):
subst.update(dict(self.get_implications(self.simplify(e))))
subst = self.axioms
else:
subst = {}
for e in axioms:
if e.free_symbols.issubset(expr.free_symbols):
subst.update(dict(self.get_implications(self.simplify(e))))
expr = expr.xreplace(subst)
# TODO: compute hint might have gotten broken here
fs = expr.free_symbols
@ -6285,6 +6287,7 @@ class ShapeEnv:
# or defer to runtime assert on.
guard = ShapeGuard(g, self._get_sloc())
self.guards.append(guard)
self.axioms.update(dict(self.get_implications(self.simplify(g))))
else:
# it's fine to defer simple guards here without checking,
# the _maybe_guard_rel() call above will set replacements if possible,
@ -6406,6 +6409,7 @@ class ShapeEnv:
# and the guard in question has no unbacked SymInts in front
ix = cands[-1] if cands else None
self.deferred_runtime_asserts.setdefault(ix, []).append(ra)
self.axioms.update(dict(self.get_implications(self.simplify(expr))))
self.num_deferred_runtime_asserts += 1
self._update_version_counter()
self._log_guard("runtime_assert", orig_expr, forcing_spec=False)