mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Don't uselessly recompute axiom dict every static eval call (#138967)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/138967 Approved by: https://github.com/ezyang
This commit is contained in:
parent
c4d9428b17
commit
6a1c451479
3 changed files with 39 additions and 24 deletions
|
|
@ -1,66 +1,65 @@
|
|||
add_loop_eager, compile_time_instruction_count, 3004749893, 0.015
|
||||
add_loop_eager,compile_time_instruction_count,3027000000,0.015
|
||||
|
||||
|
||||
|
||||
add_loop_eager_dynamic, compile_time_instruction_count, 5563298740, 0.025
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,5596000000,0.025
|
||||
|
||||
|
||||
|
||||
add_loop_inductor, compile_time_instruction_count, 24064639114, 0.015
|
||||
add_loop_inductor,compile_time_instruction_count,24260000000,0.015
|
||||
|
||||
|
||||
|
||||
add_loop_inductor_dynamic_gpu, compile_time_instruction_count, 40992578178, 0.025
|
||||
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40380000000,0.025
|
||||
|
||||
|
||||
|
||||
add_loop_inductor_gpu, compile_time_instruction_count, 22822864522, 0.015
|
||||
add_loop_inductor_gpu,compile_time_instruction_count,23010000000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_eager, compile_time_instruction_count, 1034818091, 0.015
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1028000000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_inductor, compile_time_instruction_count, 19049541914, 0.015
|
||||
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19170000000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad, compile_time_instruction_count, 15806042948, 0.015
|
||||
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15810000000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_inductor_gpu, compile_time_instruction_count, 16403080126, 0.20
|
||||
basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,16760000000,0.2
|
||||
|
||||
|
||||
|
||||
|
||||
update_hint_regression, compile_time_instruction_count, 1795333141, 0.02
|
||||
update_hint_regression,compile_time_instruction_count,1743000000,0.02
|
||||
|
||||
|
||||
|
||||
sum_floordiv_regression, compile_time_instruction_count, 1154135694, 0.015
|
||||
sum_floordiv_regression,compile_time_instruction_count,1160000000,0.015
|
||||
|
||||
|
||||
|
||||
symint_sum, compile_time_instruction_count, 3270576815, 0.015
|
||||
symint_sum,compile_time_instruction_count,3293000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_nosubclass_cpu, compile_time_instruction_count, 1981730523, 0.015
|
||||
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2001000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_subclass_cpu, compile_time_instruction_count, 5711895807, 0.015
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5778000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_partitioner_cpu, compile_time_instruction_count, 8963708885 , 0.015
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8989000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_nosubclass_cpu, compile_time_instruction_count, 3795666651, 0.015
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3822000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_subclass_cpu, compile_time_instruction_count, 10175364418, 0.015
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10260000000,0.015
|
||||
|
|
|
|||
|
|
|
@ -10252,6 +10252,9 @@ ShapeEnv not equal: field values don't match:
|
|||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {0 < Mod(s0, 3): False, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, False: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False, True: True}
|
||||
> Right: {}
|
||||
==> divisible: values don't match.
|
||||
> Left: {Mod(s0, 3)}
|
||||
> Right: {}
|
||||
|
|
@ -10289,6 +10292,9 @@ ShapeEnv not equal: field values don't match:
|
|||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {False: False, True: True}
|
||||
> Right: {}
|
||||
==> guards: values don't match.
|
||||
> Left: [Eq(s0, 3)]
|
||||
> Right: []
|
||||
|
|
@ -10330,6 +10336,9 @@ ShapeEnv not equal: field values don't match:
|
|||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {3 <= s0: True, s0 < 3: False}
|
||||
> Right: {}
|
||||
==> guards: values don't match.
|
||||
> Left: [s0 >= 3]
|
||||
> Right: []
|
||||
|
|
@ -10362,6 +10371,9 @@ ShapeEnv not equal: field values don't match:
|
|||
"""\
|
||||
ShapeEnv not equal: field values don't match:
|
||||
|
||||
==> axioms: values don't match.
|
||||
> Left: {0 < PythonMod(u0, 3): False, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, False: False, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) <= 0: True, True: True}
|
||||
> Right: {}
|
||||
==> deferred_runtime_asserts: values don't match.
|
||||
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
|
||||
> Right: {}
|
||||
|
|
|
|||
|
|
@ -3021,6 +3021,7 @@ class ShapeEnv:
|
|||
)
|
||||
|
||||
self.guards: List[ShapeGuard] = []
|
||||
self.axioms: Dict[sympy.Expr, sympy.Expr] = {}
|
||||
# Maps symbolic ints to their original concrete values
|
||||
# Currently populated from tensors
|
||||
self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {}
|
||||
|
|
@ -5347,15 +5348,16 @@ class ShapeEnv:
|
|||
expr = canonicalize_bool_expr(expr)
|
||||
|
||||
# Pattern matching
|
||||
symbols = tuple(expr.free_symbols)
|
||||
if axioms is None:
|
||||
axioms = self.get_axioms(symbols, compute_hint=compute_hint)
|
||||
subst = {}
|
||||
for e in axioms:
|
||||
if e.free_symbols.issubset(expr.free_symbols):
|
||||
subst.update(dict(self.get_implications(self.simplify(e))))
|
||||
subst = self.axioms
|
||||
else:
|
||||
subst = {}
|
||||
for e in axioms:
|
||||
if e.free_symbols.issubset(expr.free_symbols):
|
||||
subst.update(dict(self.get_implications(self.simplify(e))))
|
||||
|
||||
expr = expr.xreplace(subst)
|
||||
# TODO: compute hint might have gotten broken here
|
||||
|
||||
fs = expr.free_symbols
|
||||
|
||||
|
|
@ -6285,6 +6287,7 @@ class ShapeEnv:
|
|||
# or defer to runtime assert on.
|
||||
guard = ShapeGuard(g, self._get_sloc())
|
||||
self.guards.append(guard)
|
||||
self.axioms.update(dict(self.get_implications(self.simplify(g))))
|
||||
else:
|
||||
# it's fine to defer simple guards here without checking,
|
||||
# the _maybe_guard_rel() call above will set replacements if possible,
|
||||
|
|
@ -6406,6 +6409,7 @@ class ShapeEnv:
|
|||
# and the guard in question has no unbacked SymInts in front
|
||||
ix = cands[-1] if cands else None
|
||||
self.deferred_runtime_asserts.setdefault(ix, []).append(ra)
|
||||
self.axioms.update(dict(self.get_implications(self.simplify(expr))))
|
||||
self.num_deferred_runtime_asserts += 1
|
||||
self._update_version_counter()
|
||||
self._log_guard("runtime_assert", orig_expr, forcing_spec=False)
|
||||
|
|
|
|||
Loading…
Reference in a new issue