mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[fx][inductor] Add statically_known_true utility for SymBool (#117359)
This adds a function `statically_known_true` for `SymBool` that works like inductor's `is_expr_static_and_true`. That is, it tries to simplify the expression to a constant or returns `False` if it cannot be simplified. This is useful in cases that can be optimized if the condition is met, otherwise it doesn't effect correctness so we can avoid adding guards. I also use this new function in inductor for `FakeTensorUpdater` and `remove_noop_pass` which both generated unexpected guards previously. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117359 Approved by: https://github.com/lezcano
This commit is contained in:
parent
661747c727
commit
001585f446
5 changed files with 73 additions and 10 deletions
|
|
@ -974,6 +974,7 @@ coverage_ignore_functions = [
|
|||
"is_symbolic",
|
||||
"parallel_and",
|
||||
"parallel_or",
|
||||
"statically_known_true",
|
||||
"sym_eq",
|
||||
"canonicalize_bool_expr",
|
||||
# torch.fx.experimental.unification.core
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
|||
ShapeEnv,
|
||||
is_symbolic,
|
||||
StatelessSymbolicContext,
|
||||
statically_known_true,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
|
|
@ -633,6 +634,37 @@ class f(torch.nn.Module):
|
|||
getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None
|
||||
return (getitem, getitem_1)""") # noqa: B950
|
||||
|
||||
def test_statically_known_true(self):
|
||||
shape_env = ShapeEnv()
|
||||
s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))
|
||||
|
||||
# Statically known true
|
||||
self.assertTrue(statically_known_true(True))
|
||||
self.assertTrue(statically_known_true(s2 == s2))
|
||||
self.assertTrue(statically_known_true(s2 * s3 > s3))
|
||||
self.assertTrue(statically_known_true(s3 * s4 > s4))
|
||||
self.assertTrue(statically_known_true((s3 + s3) % 2 == 0))
|
||||
|
||||
# Statically known false
|
||||
self.assertFalse(statically_known_true(False))
|
||||
self.assertFalse(statically_known_true(s3 * s4 <= s4))
|
||||
self.assertFalse(statically_known_true((s3 + s3) % 2 == 1))
|
||||
|
||||
# True for hints, but not known statically
|
||||
self.assertFalse(statically_known_true(s2 + s2 == s4))
|
||||
self.assertFalse(statically_known_true(s4 % s2 == 0))
|
||||
self.assertFalse(statically_known_true(s2 != s3))
|
||||
self.assertFalse(statically_known_true(s3 * s4 > s2))
|
||||
|
||||
# False for hints, but not known statically
|
||||
self.assertFalse(statically_known_true(s2 == s3))
|
||||
self.assertFalse(statically_known_true(s2 > s3))
|
||||
self.assertFalse(statically_known_true(s3 + s3 == s4))
|
||||
|
||||
# No guards should be generated
|
||||
self.assertEqual(len(shape_env.guards), 0)
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
|
||||
class TestSymNumberMagicMethods(TestCase):
|
||||
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_fun
|
|||
from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
|
||||
|
||||
from torch._utils_internal import print_graph
|
||||
from torch.fx.experimental.symbolic_shapes import definitely_true, sym_eq
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
|
||||
from torch.fx.immutable_collections import immutable_dict
|
||||
|
||||
from .. import config, inductor_prims, ir, pattern_matcher
|
||||
|
|
@ -502,13 +502,13 @@ def same_meta(node1: torch.fx.Node, node2: torch.fx.Node):
|
|||
return (
|
||||
val1 is not None
|
||||
and val2 is not None
|
||||
and definitely_true(sym_eq(val1.size(), val2.size()))
|
||||
and statically_known_true(sym_eq(val1.size(), val2.size()))
|
||||
and val1.layout == val2.layout
|
||||
and val1.dtype == val2.dtype
|
||||
and val1.device == val2.device
|
||||
and (
|
||||
val1.layout != torch.strided
|
||||
or definitely_true(sym_eq(val1.stride(), val2.stride()))
|
||||
or statically_known_true(sym_eq(val1.stride(), val2.stride()))
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type
|
|||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._pytree import tree_map
|
||||
from .virtualized import V
|
||||
|
|
@ -78,6 +79,9 @@ class FakeTensorUpdater:
|
|||
for node in self.graph.nodes:
|
||||
existing_storages[get_node_storage(node)] += 1
|
||||
|
||||
def is_intlist_same(new, old):
|
||||
return statically_known_true(sym_eq(new, old))
|
||||
|
||||
def is_fake_tensor_same(new, old):
|
||||
if type(new) != type(old):
|
||||
return False
|
||||
|
|
@ -88,10 +92,16 @@ class FakeTensorUpdater:
|
|||
is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
|
||||
)
|
||||
assert isinstance(new, torch.Tensor)
|
||||
if new.shape != old.shape or new.layout != old.layout:
|
||||
if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
|
||||
return False
|
||||
if new.layout == torch.strided and new.stride() != old.stride():
|
||||
if new.layout == torch.strided and (
|
||||
not is_intlist_same(new.stride(), old.stride())
|
||||
or not statically_known_true(
|
||||
new.storage_offset() == old.storage_offset()
|
||||
)
|
||||
):
|
||||
return False
|
||||
|
||||
if get_storage(new) == get_storage(old):
|
||||
return True
|
||||
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ __all__ = [
|
|||
"hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
|
||||
"is_concrete_bool", "is_singleton", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
|
||||
"has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext",
|
||||
"StatefulSymbolicContext", "SubclassSymbolicContext"
|
||||
"StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
|
||||
]
|
||||
|
||||
# FX node metadata keys for symbolic shape FX graph.
|
||||
|
|
@ -336,16 +336,34 @@ def definitely_false(a):
|
|||
return False
|
||||
return not bool(a)
|
||||
|
||||
# TODO: could improve parallel_or/parallel_and by avoiding guards
|
||||
# if there exists a quantity that can be handled un-guardedly. However,
|
||||
# for backed SymInts, avoiding guards doesn't really matter in practice,
|
||||
# so I chose not to do it.
|
||||
def statically_known_true(x: Union[bool, SymBool]) -> bool:
|
||||
"""Returns True if x can be simplified to a constant and is true.
|
||||
|
||||
NOTE: This function doesn't introduce new guards, so the expression may end
|
||||
up evaluating to true at runtime even if this function returns False.
|
||||
|
||||
"""
|
||||
if isinstance(x, SymBool):
|
||||
expr = x.node.expr
|
||||
shape_env = x.node.shape_env
|
||||
try:
|
||||
simplified = shape_env._maybe_evaluate_static(expr)
|
||||
if simplified is not None:
|
||||
return bool(simplified)
|
||||
except Exception:
|
||||
log.debug("Could not simplify %s", expr)
|
||||
return False
|
||||
assert isinstance(x, bool)
|
||||
return x
|
||||
|
||||
|
||||
def parallel_or(*args):
|
||||
"""
|
||||
Evaluate the logical OR of several arguments, avoiding guarding on
|
||||
unbacked SymInts if another argument is definitely True.
|
||||
"""
|
||||
if any(statically_known_true(a) for a in args):
|
||||
return True
|
||||
if any(definitely_true(a) for a in args):
|
||||
return True
|
||||
return any(args)
|
||||
|
|
@ -355,6 +373,8 @@ def parallel_and(*args):
|
|||
Evaluate the logical FALSE of several arguments, avoiding guarding on
|
||||
unbacked SymInts if another argument is definitely False.
|
||||
"""
|
||||
if any(statically_known_true(torch.sym_not(a)) for a in args):
|
||||
return False
|
||||
if any(definitely_false(a) for a in args):
|
||||
return False
|
||||
return all(args)
|
||||
|
|
|
|||
Loading…
Reference in a new issue