Revert "Introduce torch.sym_sum (#136429)"

This reverts commit 90bed32b98.

Reverted https://github.com/pytorch/pytorch/pull/136429 on behalf of https://github.com/ezyang due to fails internal stuff ([comment](https://github.com/pytorch/pytorch/pull/136429#issuecomment-2403335147))
PyTorch MergeBot 2024-10-09 20:08:01 +00:00
parent 572f506f9c
commit 16a2c2cfd4
17 changed files with 18 additions and 255 deletions


@@ -1,44 +0,0 @@
-import sys
-
-from benchmark_base import BenchmarkBase
-
-import torch
-
-
-class Benchmark(BenchmarkBase):
-    N = 200
-
-    def name(self):
-        return "symint_sum"
-
-    def description(self):
-        return "see https://docs.google.com/document/d/11xJXl1etSmefUxPiVyk885e0Dl-4o7QwxYcPiMIo2iY/edit"
-
-    def _prepare_once(self):
-        torch._dynamo.config.capture_scalar_outputs = True
-        torch.manual_seed(0)
-
-        self.splits = torch.randint(10, (self.N,))
-
-    def _prepare(self):
-        torch._dynamo.reset()
-
-    def _work(self):
-        @torch.compile(fullgraph=True)
-        def f(a):
-            xs = a.tolist()
-            y = sum(xs)
-            return torch.tensor(y)
-
-        f(self.splits)
-
-
-def main():
-    result_path = sys.argv[1]
-    Benchmark().enable_compile_time_instruction_count().collect_all().append_results(
-        result_path
-    )
-
-
-if __name__ == "__main__":
-    main()
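
For orientation (editorial note, not part of the diff): the deleted benchmark compiles a `sum` over `Tensor.tolist()` with `capture_scalar_outputs=True`, so each of the 200 list elements becomes an unbacked SymInt and the sum builds a long chain of symbolic adds unless an n-ary `torch.sym_sum` is available. A minimal eager-mode sketch of the same computation, using only public APIs:

```python
import torch

# Eager equivalent of the workload the removed benchmark compiles:
# read 200 small ints out of a tensor, sum them in Python, re-wrap.
splits = torch.randint(10, (200,))
total = torch.tensor(sum(splits.tolist()))
print(total)  # 0-dim int64 tensor holding the sum
```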


@@ -731,7 +731,6 @@ Symbolic Numbers
     sym_min
     sym_not
     sym_ite
-    sym_sum

 Export Path
 -------------


@@ -253,20 +253,6 @@ class TestInductorDynamic(TestCase):
         opt_r = opt_f()
         self.assertEqual(r, opt_r)

-    @torch._dynamo.config.patch(capture_scalar_outputs=True)
-    def test_sym_sum_unbacked(self, device):
-        def f(a):
-            xs = a.tolist()
-            y = sum(xs)
-            return torch.tensor(y)
-
-        splits = torch.randint(10, (100,), device=device)
-        opt_f = torch.compile(f, fullgraph=True)
-        r = f(splits)
-        opt_r = opt_f(splits)
-        self.assertEqual(r, opt_r)
-
     @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
     def test_nonzero_size_factory_nobreak(self, device):
         def f(x, b):


@@ -451,15 +451,6 @@ class TestPySymInt(TestCase):
         self.assertEqual(guard_int(a0), 2)
         self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")

-    def test_sym_sum(self):
-        shape_env = ShapeEnv()
-        s0 = create_symint(shape_env, 2)
-        s1 = create_symint(shape_env, 3)
-        s2 = create_symint(shape_env, 4)
-        self.assertEqual(
-            (s0 + s1 + s2).node.expr, torch.sym_sum([s0, s1, s2]).node.expr
-        )
-
     def test_prefer_deferred_runtime_assertions_over_guards(self):
         shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
         s0 = create_symint(shape_env, 2)


@@ -691,10 +691,7 @@ def generate_tensor_like_override_tests(cls):
                 f"Unsupported argument type {arg_type} for {arg_name} of function {func}"
             )

-        # Special case; this doesn't have a schema but takes a list
-        if func is torch.sym_sum:
-            func_args.append([TensorLike(), TensorLike()])
-        elif func in annotated_args:
+        if func in annotated_args:
             for arg in annotated_args[func]:
                 # Guess valid input to aten function based on type of argument
                 t = arg["simple_type"]


@@ -133,7 +133,6 @@ __all__ = [
     "sym_max",
     "sym_min",
     "sym_not",
-    "sym_sum",
     "typename",
     "unravel_index",
     "use_deterministic_algorithms",
@@ -847,28 +846,6 @@ def sym_min(a, b):
     return builtins.min(a, b)


-def sym_sum(args):
-    """
-    N-ary add which is faster to compute for long lists than iterated binary
-    addition. Only does something special for integers.
-    """
-    if overrides.has_torch_function(args):
-        return overrides.handle_torch_function(sym_sum, args, args)
-
-    found = None
-    for a in args:
-        if not isinstance(a, (SymInt, builtins.int)):
-            return builtins.sum(args)
-        if isinstance(a, SymInt):
-            found = a.node
-    if found is None:
-        return builtins.sum(args)
-
-    from torch.fx.experimental.sym_node import to_node, wrap_node
-
-    return wrap_node(found.sym_sum(tuple(to_node(found, a) for a in args)))
-
-
 # Drop in replacement for math.sqrt, math.sin, math.cos etc
 def _get_sym_math_fn(name):
     def fn(a):
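
For reference (editorial note, not part of the diff): the removed `torch.sym_sum` was a drop-in, n-ary replacement for `builtins.sum` over lists of `int`/`SymInt`; per the docstring above, it only does something special for integers, and on plain Python ints it falls through to `builtins.sum`. A minimal sketch, valid only on builds where #136429 is applied:

```python
import torch

# Pre-revert behavior: same value as builtins.sum, but a single n-ary
# symbolic add when SymInts are involved. On plain ints the removed
# implementation simply delegated to builtins.sum.
assert torch.sym_sum([1, 2, 3]) == 6
assert torch.sym_sum([1, 2, 3]) == sum([1, 2, 3])
```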


@@ -1319,7 +1319,6 @@ class OutputGraph:
                 fx.GraphModule(root, self.graph),
                 self.shape_env,
                 name,
-                export=self.export,
             )

             # NB: deferred runtime asserts can keep graphargs live, so make sure
             # those are inserted before pruning


@@ -1,6 +1,5 @@
 # mypy: ignore-errors

-import builtins
 import collections
 import functools
 import inspect
@@ -1013,36 +1012,6 @@ class PolyfilledFunctionVariable(VariableTracker):
             )
             return SourcelessBuilder.create(tx, result)

-        # Special case for sum on tuple/list of ints
-        if (
-            self.fn is builtins.sum
-            and len(args) == 1
-            and not kwargs
-            and isinstance(args[0], (variables.ListVariable, variables.TupleVariable))
-            and all(
-                (isinstance(x, variables.ConstantVariable) and isinstance(x.value, int))
-                or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int)
-                for x in args[0].items
-            )
-        ):
-            return variables.SymNodeVariable.create(
-                tx,
-                tx.output.create_proxy(
-                    "call_function",
-                    torch.sym_sum,
-                    (tuple(a.as_proxy() for a in args[0].items),),
-                    {},
-                ),
-                sym_num=torch.sym_sum(
-                    [
-                        x.value
-                        if isinstance(x, variables.ConstantVariable)
-                        else x.sym_num
-                        for x in args[0].items
-                    ]
-                ),
-            )
-
         traceable_function_variable = SourcelessBuilder.create(tx, self.traceable_fn)
         return traceable_function_variable.call_function(tx, args, kwargs)
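
For context (editorial note, not part of the diff): the removed special case made Dynamo emit one `torch.sym_sum` node when tracing `builtins.sum` over a list or tuple whose elements are all constant ints or integer SymNodes. A hedged sketch of a user-level trigger, assuming a pre-revert build (with `capture_scalar_outputs`, the `.tolist()` elements become integer SymNodes):

```python
import torch

torch._dynamo.config.capture_scalar_outputs = True

# Tracing builtins.sum over a list of integer SymNodes: pre-revert the
# special case above collapsed this into a single torch.sym_sum node;
# post-revert the generic polyfill traces it as iterated binary adds.
@torch.compile(fullgraph=True)
def f(a):
    return torch.tensor(sum(a.tolist()))

print(f(torch.randint(10, (8,))))
```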


@@ -6207,11 +6207,6 @@ for method, func in magic_methods.items():
     register_lowering(method_to_operator(method))(func)


-@register_lowering(torch.sym_sum)
-def sym_sum(args):
-    return sympy.Add(*args)
-
-
 @register_lowering(aten._foobar)
 def foobar(self, *args, **kwargs):
     raise NotImplementedError("Helpful for debugging")


@@ -720,12 +720,6 @@ def sym_constrain_range_for_size(size, min=None, max=None):
     if isinstance(size, (SymFloat, SymBool)):
         raise ValueError("Constraining SymFloat or Symbool is nyi")

-    if type(size) is int:
-        if min is not None:
-            torch._check(size >= min)
-        if max is not None:
-            torch._check(size <= max)
-        return
     _constrain_range_for_size(size, min=min, max=max)
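
For context (editorial note, not part of the diff): the removed fastpath handled plain Python ints without touching the ShapeEnv, validating the bounds eagerly with `torch._check` and returning. A minimal sketch of that eager check using the public `torch._check` API:

```python
import torch

# Eager bounds validation as in the removed int fastpath:
# torch._check raises if the condition is False.
size, lo, hi = 5, 0, 10
torch._check(size >= lo)
torch._check(size <= hi)
```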


@@ -1356,24 +1356,12 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
     def _compute_proxy(
         self, func: OpOverload, args: Tuple[object, ...], out: PySymType
     ) -> Proxy:
-        # Handle torch.sym_sum
-        n_args: Tuple[object, ...]
-        if len(args) == 1 and isinstance(args[0], (list, tuple)):
-            n_args = (
-                tuple(
-                    get_proxy_slot(a, self.tracer).force().node
-                    if isinstance(a, py_sym_types)
-                    else a
-                    for a in args[0]
-                ),
-            )
-        else:
-            n_args = tuple(
-                get_proxy_slot(a, self.tracer).force().node
-                if isinstance(a, py_sym_types)
-                else a
-                for a in args
-            )
+        n_args = tuple(
+            get_proxy_slot(a, self.tracer).force().node
+            if isinstance(a, py_sym_types)
+            else a
+            for a in args
+        )

         # func doesn't have a __torch_function__ that Proxy can interpose, so
         # we gotta do it manually


@@ -422,47 +422,6 @@ class SymNode:
     def int_(self):
         return self.guard_int("", 0)  # NB: uses Python backtrace

-    # This one is currently done by hand, but if we add other variadic
-    # functions consider factoring it out to be metaprogrammed too. Note that
-    # some load bearing logic is directly in torch.sym_sum
-    def sym_sum(self, args) -> "SymNode":
-        import sympy
-
-        # Inner impl
-        from torch.fx.experimental.proxy_tensor import (
-            get_proxy_mode,
-            handle_sym_dispatch,
-        )
-
-        if get_proxy_mode():
-            return to_node(
-                self,
-                handle_sym_dispatch(
-                    torch.sym_sum,
-                    (tuple(wrap_node(a) for a in args),),
-                    {},
-                ),
-            )
-
-        exprs = [a.expr for a in args]
-        out = sympy.Add(*exprs)
-
-        size_hints = []
-        out_hint = None
-        for a in args:
-            if a.hint is None:
-                break
-            size_hints.append(a.hint)
-        else:
-            out_hint = sum(size_hints)
-
-        fx_node, _ = self.shape_env._create_fx_call_function(
-            torch.sym_sum, (tuple(a.fx_node for a in args),)
-        )
-
-        # NB: Only for integers!
-        return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node)
-
     # You can manually trigger a guard with this function
     def guard_int(self, file, line):
         # TODO: use the file/line for some useful diagnostic on why a
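
For context (editorial note, not part of the diff): the removed `SymNode.sym_sum` built a single flat `sympy.Add` from all operand expressions. Because sympy flattens nested additions, this matches the expression produced by chained binary adds, which is exactly what the removed `test_sym_sum` asserted. A standalone sketch of that invariant (sympy only, no torch):

```python
import sympy

# sympy.Add flattens nested sums, so one n-ary Add is structurally equal
# to the result of left-to-right binary addition.
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
assert sympy.Add(s0, s1, s2) == (s0 + s1) + s2
```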


@@ -163,9 +163,6 @@ try:
         def to_int(x: z3.ArithRef) -> z3.ArithRef:
             return x if x.is_int() else z3.ToInt(x)

-        def sym_sum(self, args: z3.ArithRef) -> z3.ArithRef:
-            return sum(args)
-
         # Implements Python division semantics.
         def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
             self.validator.add_assertion(denominator != 0)  # type: ignore[arg-type]
@@ -268,10 +265,7 @@ try:
         @functools.wraps(func)
         def wrapper(*args):
             # Lifts the arguments into a list of Z3 inhabitants.
-            if len(args) == 1 and isinstance(args[0], (list, tuple)):
-                wrapped_args = (tuple(wrap(a) for a in args[0]),)
-            else:
-                wrapped_args = tuple(wrap(a) for a in args)
+            wrapped_args = (wrap(a) for a in args)

             # Run the function on the Z3 expressions.
             return func(*wrapped_args)
@@ -295,7 +289,6 @@ try:
             torch.sym_float: lift(ops.to_real),
             torch.sym_max: lift(ops.max),
             torch.sym_min: lift(ops.min),
-            torch.sym_sum: lift(ops.sym_sum),
             torch.sym_ite: lift(lambda b, t, f: t if b else f),
             torch._sym_sqrt: lift(ops.sqrt),  # type: ignore[attr-defined]
             # Not lifted because we only use this function as a
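
For context (editorial note, not part of the diff): the removed `ops.sym_sum` relied on Python's builtin `sum` folding `+` over Z3 integer terms. A standalone illustration, assuming the `z3-solver` package this file already depends on:

```python
import z3

# builtins.sum starts from 0 and folds +, which z3.ArithRef overloads,
# so a list of Z3 integer terms sums to one arithmetic expression.
a, b, c = z3.Ints("a b c")
print(sum([a, b, c]))  # 0 + a + b + c
```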


@@ -105,16 +105,12 @@ def insert_deferred_runtime_asserts(
         resolve_unbacked_bindings,
     )
     from torch.utils._sympy.numbers import int_oo
-    from torch.utils._sympy.reference import (
-        OptimizedPythonReferenceAnalysis,
-        PythonReferenceAnalysis,
-    )
+    from torch.utils._sympy.reference import PythonReferenceAnalysis
     from torch.utils._sympy.value_ranges import ValueRanges

     # TODO: Request simplification on runtime asserts before emitting them
     ras_by_symbol = shape_env.deferred_runtime_asserts.copy()
     graph = gm.graph
-    tracer = fx.proxy.GraphAppendingTracer(graph)
     graph_code_log.debug(
         "%s",
         lazy_format_graph_code(
@@ -165,12 +161,10 @@
         stack_trace: Optional[str] = None,
         nn_module_stack: Optional[Dict[str, Any]] = None,
     ) -> None:
-        fake_args = pytree.tree_map(
-            lambda arg: _get_example_value(arg)
-            if isinstance(arg, torch.fx.Node)
-            else arg,
-            node.args,
-        )
+        fake_args = [
+            _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg
+            for arg in node.args
+        ]
         try:
             node.meta[val_key] = node.target(*fake_args)  # type: ignore[operator]
         except NotImplementedError:
@@ -187,8 +181,6 @@
     added_asserts: Set[sympy.Expr] = set()
     constrained_unbacked_symbols: Set[sympy.Symbol] = set()

-    Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis
-
     def _sympy_interp(expr_to_proxy, expr):
         # sympy_interp() with hash consing
         from sympy import Integer, Number, Symbol
@@ -201,11 +193,11 @@
             return expr_to_proxy[expr]
         # base cases, don't cache
         if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)):
-            return sympy_interp(Analysis, expr_to_proxy, expr)
+            return sympy_interp(PythonReferenceAnalysis, expr_to_proxy, expr)

         # hash cons on arguments, run expr handler
         expr_to_proxy[expr] = _run_sympy_handler(
-            Analysis,
+            PythonReferenceAnalysis,
             [_sympy_interp(expr_to_proxy, arg) for arg in expr.args],
             expr,
         )
@@ -289,7 +281,7 @@
                     and s not in expr_to_proxy
                 ):
                     with _set_node_metadata_hook(gm, _node_metadata_hook):
-                        expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer)
+                        expr_to_proxy[s] = fx.Proxy(cb())
                     log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])

         match_symbol(example_value, lambda: node)
@@ -395,7 +387,7 @@
                 elif sym_expr not in expr_to_proxy and not isinstance(
                     sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
                 ):  # don't hash cons primitives
-                    expr_to_proxy[sym_expr] = fx.Proxy(node, tracer=tracer)  # type: ignore[arg-type]
+                    expr_to_proxy[sym_expr] = fx.Proxy(node)  # type: ignore[arg-type]

             # We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained,
             # so calls before that are redundant.
@@ -488,9 +480,7 @@
                     if s not in expr_to_proxy:
                         with _set_node_metadata_hook(gm, _node_metadata_hook):
-                            expr_to_proxy[s] = fx.Proxy(
-                                go(node, keypath), tracer=tracer
-                            )
+                            expr_to_proxy[s] = fx.Proxy(go(node, keypath))
                         log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])

                 for i0 in defs:


@@ -1139,7 +1139,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.sym_min: lambda a, b: -1,
         torch.sym_not: lambda input: -1,
         torch.sym_ite: lambda a, b, c: -1,
-        torch.sym_sum: lambda args: -1,
         torch._sym_sqrt: lambda input: -1,
         torch._sym_cos: lambda input: -1,
         torch._sym_cosh: lambda input: -1,


@@ -138,12 +138,6 @@ def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
     if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
         return getattr(analysis, handler_name)(*args, index_dtype)

-    # Fastpath for n-ary integral addition
-    if expr.func is sympy.Add and expr.is_integer and hasattr(analysis, "sym_sum"):
-        r = analysis.sym_sum(args)
-        log.debug("sym_sum(%s) -> %s", args, r)
-        return r
-
     if hasattr(expr.func, "_torch_handler_name"):
         handler_name = expr.func._torch_handler_name
     else:
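
For context (editorial note, not part of the diff): the removed fastpath fired when a sympy node is an integer n-ary addition, dispatching all terms to `analysis.sym_sum` in one call instead of walking a chain of binary `add` handlers. A small standalone check of the condition it tested (sympy only):

```python
import sympy

# An addition of integer symbols is one flat Add node with all terms as
# args, which is what made a single n-ary sym_sum dispatch possible.
x, y, z = sympy.symbols("x y z", integer=True)
expr = x + y + z
assert expr.func is sympy.Add and expr.is_integer
assert expr.args == (x, y, z)
```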


@@ -142,10 +142,6 @@ class ReferenceAnalysis:
     def add(a, b):
         return _keep_float(operator.add)(a, b)

-    @classmethod
-    def sym_sum(cls, args):
-        return sympy.Add(*args)
-
     @staticmethod
     def mul(a, b):
         return _keep_float(operator.mul)(a, b)
@@ -210,17 +206,6 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
     def not_(a):
         return torch.sym_not(a)

-    @classmethod
-    def sym_sum(cls, args):
-        if len(args) == 0:
-            return 0
-        if len(args) == 1:
-            return args[0]
-        acc = cls.add(args[0], args[1])
-        for i in range(2, len(args)):
-            acc = cls.add(acc, args[i])
-        return acc
-
     @staticmethod
     def floordiv(a, b):
         return a // b
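
For context (editorial note, not part of the diff): the removed `PythonReferenceAnalysis.sym_sum` is a plain left fold of binary adds with explicit empty and singleton cases. For non-symbolic inputs it behaves like the hypothetical standalone version below, which uses `operator.add` in place of `cls.add`:

```python
import functools
import operator

# Standalone equivalent of the removed left-fold reduction.
def sym_sum(args):
    if len(args) == 0:
        return 0
    if len(args) == 1:
        return args[0]
    return functools.reduce(operator.add, args)

assert sym_sum([]) == 0
assert sym_sum([7]) == 7
assert sym_sum([2, 3, 4]) == 9
```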
@@ -299,14 +284,6 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
         return round(a, ndigits=b)

-# Like PythonReferenceAnalysis, but some export-unfriendly choices of
-# operators to make things faster
-class OptimizedPythonReferenceAnalysis(PythonReferenceAnalysis):
-    @staticmethod
-    def sym_sum(args):
-        return torch.sym_sum(args)
-
-
 def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     return torch.ops.prims.convert_element_type.default(x, dtype)