[fx][inductor] Add statically_known_true utility for SymBool (#117359)

This adds a function `statically_known_true` for `SymBool` that works
like inductor's `is_expr_static_and_true`. That is, it tries to simplify the
expression to a constant or returns `False` if it cannot be simplified.

This is useful in cases that can be optimized if the condition is met,
otherwise it doesn't effect correctness so we can avoid adding guards.

I also use this new function in inductor for `FakeTensorUpdater` and
`remove_noop_pass` which both generated unexpected guards previously.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117359
Approved by: https://github.com/lezcano
This commit is contained in:
Peter Bell 2024-01-15 14:07:46 +00:00 committed by PyTorch MergeBot
parent 661747c727
commit 001585f446
5 changed files with 73 additions and 10 deletions

View file

@ -974,6 +974,7 @@ coverage_ignore_functions = [
"is_symbolic",
"parallel_and",
"parallel_or",
"statically_known_true",
"sym_eq",
"canonicalize_bool_expr",
# torch.fx.experimental.unification.core

View file

@ -28,6 +28,7 @@ from torch.fx.experimental.symbolic_shapes import (
ShapeEnv,
is_symbolic,
StatelessSymbolicContext,
statically_known_true,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@ -633,6 +634,37 @@ class f(torch.nn.Module):
getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""") # noqa: B950
def test_statically_known_true(self):
shape_env = ShapeEnv()
s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))
# Statically known true
self.assertTrue(statically_known_true(True))
self.assertTrue(statically_known_true(s2 == s2))
self.assertTrue(statically_known_true(s2 * s3 > s3))
self.assertTrue(statically_known_true(s3 * s4 > s4))
self.assertTrue(statically_known_true((s3 + s3) % 2 == 0))
# Statically known false
self.assertFalse(statically_known_true(False))
self.assertFalse(statically_known_true(s3 * s4 <= s4))
self.assertFalse(statically_known_true((s3 + s3) % 2 == 1))
# True for hints, but not known statically
self.assertFalse(statically_known_true(s2 + s2 == s4))
self.assertFalse(statically_known_true(s4 % s2 == 0))
self.assertFalse(statically_known_true(s2 != s3))
self.assertFalse(statically_known_true(s3 * s4 > s2))
# False for hints, but not known statically
self.assertFalse(statically_known_true(s2 == s3))
self.assertFalse(statically_known_true(s2 > s3))
self.assertFalse(statically_known_true(s3 + s3 == s4))
# No guards should be generated
self.assertEqual(len(shape_env.guards), 0)
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestSymNumberMagicMethods(TestCase):
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):

View file

@ -17,7 +17,7 @@ from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_fun
from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
from torch._utils_internal import print_graph
from torch.fx.experimental.symbolic_shapes import definitely_true, sym_eq
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
from torch.fx.immutable_collections import immutable_dict
from .. import config, inductor_prims, ir, pattern_matcher
@ -502,13 +502,13 @@ def same_meta(node1: torch.fx.Node, node2: torch.fx.Node):
return (
val1 is not None
and val2 is not None
and definitely_true(sym_eq(val1.size(), val2.size()))
and statically_known_true(sym_eq(val1.size(), val2.size()))
and val1.layout == val2.layout
and val1.dtype == val2.dtype
and val1.device == val2.device
and (
val1.layout != torch.strided
or definitely_true(sym_eq(val1.stride(), val2.stride()))
or statically_known_true(sym_eq(val1.stride(), val2.stride()))
)
)

View file

@ -3,6 +3,7 @@ from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type
import torch
import torch.fx
from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map
from .virtualized import V
@ -78,6 +79,9 @@ class FakeTensorUpdater:
for node in self.graph.nodes:
existing_storages[get_node_storage(node)] += 1
def is_intlist_same(new, old):
return statically_known_true(sym_eq(new, old))
def is_fake_tensor_same(new, old):
if type(new) != type(old):
return False
@ -88,10 +92,16 @@ class FakeTensorUpdater:
is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
)
assert isinstance(new, torch.Tensor)
if new.shape != old.shape or new.layout != old.layout:
if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
return False
if new.layout == torch.strided and new.stride() != old.stride():
if new.layout == torch.strided and (
not is_intlist_same(new.stride(), old.stride())
or not statically_known_true(
new.storage_offset() == old.storage_offset()
)
):
return False
if get_storage(new) == get_storage(old):
return True

View file

@ -81,7 +81,7 @@ __all__ = [
"hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node",
"is_concrete_bool", "is_singleton", "SHAPEENV_EVENT_KEY", "CURRENT_NODE_KEY",
"has_free_symbols", "sym_eq", "SymbolicContext", "StatelessSymbolicContext",
"StatefulSymbolicContext", "SubclassSymbolicContext"
"StatefulSymbolicContext", "SubclassSymbolicContext", "statically_known_true",
]
# FX node metadata keys for symbolic shape FX graph.
@ -336,16 +336,34 @@ def definitely_false(a):
return False
return not bool(a)
# TODO: could improve parallel_or/parallel_and by avoiding guards
# if there exists a quantity that can be handled un-guardedly. However,
# for backed SymInts, avoiding guards doesn't really matter in practice,
# so I chose not to do it.
def statically_known_true(x: Union[bool, SymBool]) -> bool:
"""Returns True if x can be simplified to a constant and is true.
NOTE: This function doesn't introduce new guards, so the expression may end
up evaluating to true at runtime even if this function returns False.
"""
if isinstance(x, SymBool):
expr = x.node.expr
shape_env = x.node.shape_env
try:
simplified = shape_env._maybe_evaluate_static(expr)
if simplified is not None:
return bool(simplified)
except Exception:
log.debug("Could not simplify %s", expr)
return False
assert isinstance(x, bool)
return x
def parallel_or(*args):
"""
Evaluate the logical OR of several arguments, avoiding guarding on
unbacked SymInts if another argument is definitely True.
"""
if any(statically_known_true(a) for a in args):
return True
if any(definitely_true(a) for a in args):
return True
return any(args)
@ -355,6 +373,8 @@ def parallel_and(*args):
Evaluate the logical FALSE of several arguments, avoiding guarding on
unbacked SymInts if another argument is definitely False.
"""
if any(statically_known_true(torch.sym_not(a)) for a in args):
return False
if any(definitely_false(a) for a in args):
return False
return all(args)