mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "Introduce torch.sym_sum (#136429)"
This reverts commit 90bed32b98.
Reverted https://github.com/pytorch/pytorch/pull/136429 on behalf of https://github.com/ezyang due to fails internal stuff ([comment](https://github.com/pytorch/pytorch/pull/136429#issuecomment-2403335147))
This commit is contained in:
parent
572f506f9c
commit
16a2c2cfd4
17 changed files with 18 additions and 255 deletions
|
|
@ -1,44 +0,0 @@
|
|||
import sys
|
||||
|
||||
from benchmark_base import BenchmarkBase
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Benchmark(BenchmarkBase):
    """Benchmark that compiles a function summing the items of an integer tensor.

    The compiled function converts a tensor to a Python list with
    ``tolist()``, adds the items up with the builtin ``sum`` and wraps the
    total in a new tensor.
    """

    # Length of the random integer tensor fed to the compiled function.
    N = 200

    def name(self):
        """Return the identifier string of this benchmark."""
        return "symint_sum"

    def description(self):
        """Return a pointer to the document describing this benchmark."""
        return "see https://docs.google.com/document/d/11xJXl1etSmefUxPiVyk885e0Dl-4o7QwxYcPiMIo2iY/edit"

    def _prepare_once(self):
        """Set the dynamo config flag, seed the RNG and build the input tensor."""
        # NOTE: this assigns the global dynamo config; it is not restored here.
        torch._dynamo.config.capture_scalar_outputs = True
        torch.manual_seed(0)

        # N random integers drawn from [0, 10).
        self.splits = torch.randint(10, (self.N,))

    def _prepare(self):
        """Reset dynamo state."""
        torch._dynamo.reset()

    def _work(self):
        """Compile the summation function with ``fullgraph=True`` and call it."""

        def f(a):
            items = a.tolist()
            total = sum(items)
            return torch.tensor(total)

        compiled_f = torch.compile(f, fullgraph=True)
        compiled_f(self.splits)
|
||||
|
||||
|
||||
def main():
    """Collect the benchmark with compile-time instruction counting enabled.

    The first command-line argument (``sys.argv[1]``) is handed to
    ``append_results`` as the results path.
    """
    output_path = sys.argv[1]
    bench = Benchmark().enable_compile_time_instruction_count()
    bench.collect_all().append_results(output_path)
|
||||
|
||||
|
||||
# Run the benchmark only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -731,7 +731,6 @@ Symbolic Numbers
|
|||
sym_min
|
||||
sym_not
|
||||
sym_ite
|
||||
sym_sum
|
||||
|
||||
Export Path
|
||||
-------------
|
||||
|
|
|
|||
|
|
@ -253,20 +253,6 @@ class TestInductorDynamic(TestCase):
|
|||
opt_r = opt_f()
|
||||
self.assertEqual(r, opt_r)
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_sym_sum_unbacked(self, device):
    """Eager and compiled results must agree for ``sum`` over ``tolist()`` items.

    The function under test turns a 100-element integer tensor into a Python
    list, sums it with the builtin ``sum`` and returns the total as a tensor.
    It is compiled with ``fullgraph=True``.
    """

    def f(a):
        items = a.tolist()
        total = sum(items)
        return torch.tensor(total)

    splits = torch.randint(10, (100,), device=device)

    compiled_f = torch.compile(f, fullgraph=True)
    expected = f(splits)
    actual = compiled_f(splits)
    self.assertEqual(expected, actual)
|
||||
|
||||
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
|
||||
def test_nonzero_size_factory_nobreak(self, device):
|
||||
def f(x, b):
|
||||
|
|
|
|||
|
|
@ -451,15 +451,6 @@ class TestPySymInt(TestCase):
|
|||
self.assertEqual(guard_int(a0), 2)
|
||||
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 2)""")
|
||||
|
||||
def test_sym_sum(self):
    """``torch.sym_sum`` must yield the same expression as chained ``+``."""
    env = ShapeEnv()
    symints = [create_symint(env, hint) for hint in (2, 3, 4)]
    first, second, third = symints

    chained_expr = (first + second + third).node.expr
    n_ary_expr = torch.sym_sum(symints).node.expr
    self.assertEqual(chained_expr, n_ary_expr)
|
||||
|
||||
def test_prefer_deferred_runtime_assertions_over_guards(self):
|
||||
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
|
||||
s0 = create_symint(shape_env, 2)
|
||||
|
|
|
|||
|
|
@ -691,10 +691,7 @@ def generate_tensor_like_override_tests(cls):
|
|||
f"Unsupported argument type {arg_type} for {arg_name} of function {func}"
|
||||
)
|
||||
|
||||
# Special case; this doesn't have a schema but takes a list
|
||||
if func is torch.sym_sum:
|
||||
func_args.append([TensorLike(), TensorLike()])
|
||||
elif func in annotated_args:
|
||||
if func in annotated_args:
|
||||
for arg in annotated_args[func]:
|
||||
# Guess valid input to aten function based on type of argument
|
||||
t = arg["simple_type"]
|
||||
|
|
|
|||
|
|
@ -133,7 +133,6 @@ __all__ = [
|
|||
"sym_max",
|
||||
"sym_min",
|
||||
"sym_not",
|
||||
"sym_sum",
|
||||
"typename",
|
||||
"unravel_index",
|
||||
"use_deterministic_algorithms",
|
||||
|
|
@ -847,28 +846,6 @@ def sym_min(a, b):
|
|||
return builtins.min(a, b)
|
||||
|
||||
|
||||
def sym_sum(args):
    """
    N-ary add which is faster to compute for long lists than iterated binary
    addition.  Only does something special for integers.

    Args:
        args: iterable of addends.  Lists and tuples are used as given; any
            other iterable (including one-shot iterators and generators) is
            first materialized into a tuple.

    Returns:
        ``builtins.sum(args)`` when any addend is neither a ``SymInt`` nor an
        ``int``, or when no addend is a ``SymInt``.  Otherwise the wrapped
        result of a single n-ary ``sym_sum`` call on the symbolic node.
    """
    # ``args`` is traversed more than once below (the type scan, then the
    # actual summation).  A one-shot iterator would be exhausted by the first
    # pass and silently produce a truncated result, so materialize it.  Lists
    # and tuples are passed through untouched so that __torch_function__
    # overrides receive exactly the object the caller supplied.
    if not isinstance(args, (list, tuple)):
        args = tuple(args)

    if overrides.has_torch_function(args):
        return overrides.handle_torch_function(sym_sum, args, args)

    # Node of the last SymInt seen; it is the receiver of the n-ary sum and
    # the reference used to convert the other addends into nodes.
    found = None
    for a in args:
        if not isinstance(a, (SymInt, builtins.int)):
            # Anything that is not an integer falls back to the builtin sum.
            return builtins.sum(args)
        if isinstance(a, SymInt):
            found = a.node
    if found is None:
        # Only plain ints: nothing symbolic to do.
        return builtins.sum(args)

    # Imported lazily, inside the function, as in the original code.
    from torch.fx.experimental.sym_node import to_node, wrap_node

    return wrap_node(found.sym_sum(tuple(to_node(found, a) for a in args)))
|
||||
|
||||
|
||||
# Drop in replacement for math.sqrt, math.sin, math.cos etc
|
||||
def _get_sym_math_fn(name):
|
||||
def fn(a):
|
||||
|
|
|
|||
|
|
@ -1319,7 +1319,6 @@ class OutputGraph:
|
|||
fx.GraphModule(root, self.graph),
|
||||
self.shape_env,
|
||||
name,
|
||||
export=self.export,
|
||||
)
|
||||
# NB: deferred runtime asserts can keep graphargs live, so make sure
|
||||
# those are inserted before pruning
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
import builtins
|
||||
import collections
|
||||
import functools
|
||||
import inspect
|
||||
|
|
@ -1013,36 +1012,6 @@ class PolyfilledFunctionVariable(VariableTracker):
|
|||
)
|
||||
return SourcelessBuilder.create(tx, result)
|
||||
|
||||
# Special case for sum on tuple/list of ints
|
||||
if (
|
||||
self.fn is builtins.sum
|
||||
and len(args) == 1
|
||||
and not kwargs
|
||||
and isinstance(args[0], (variables.ListVariable, variables.TupleVariable))
|
||||
and all(
|
||||
(isinstance(x, variables.ConstantVariable) and isinstance(x.value, int))
|
||||
or (isinstance(x, variables.SymNodeVariable) and x.python_type() is int)
|
||||
for x in args[0].items
|
||||
)
|
||||
):
|
||||
return variables.SymNodeVariable.create(
|
||||
tx,
|
||||
tx.output.create_proxy(
|
||||
"call_function",
|
||||
torch.sym_sum,
|
||||
(tuple(a.as_proxy() for a in args[0].items),),
|
||||
{},
|
||||
),
|
||||
sym_num=torch.sym_sum(
|
||||
[
|
||||
x.value
|
||||
if isinstance(x, variables.ConstantVariable)
|
||||
else x.sym_num
|
||||
for x in args[0].items
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
traceable_function_variable = SourcelessBuilder.create(tx, self.traceable_fn)
|
||||
return traceable_function_variable.call_function(tx, args, kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -6207,11 +6207,6 @@ for method, func in magic_methods.items():
|
|||
register_lowering(method_to_operator(method))(func)
|
||||
|
||||
|
||||
@register_lowering(torch.sym_sum)
def sym_sum(args):
    """Lower ``torch.sym_sum`` to a single n-ary ``sympy.Add`` over ``args``."""
    operands = tuple(args)
    return sympy.Add(*operands)
|
||||
|
||||
|
||||
@register_lowering(aten._foobar)
|
||||
def foobar(self, *args, **kwargs):
|
||||
raise NotImplementedError("Helpful for debugging")
|
||||
|
|
|
|||
|
|
@ -720,12 +720,6 @@ def sym_constrain_range_for_size(size, min=None, max=None):
|
|||
|
||||
if isinstance(size, (SymFloat, SymBool)):
|
||||
raise ValueError("Constraining SymFloat or Symbool is nyi")
|
||||
if type(size) is int:
|
||||
if min is not None:
|
||||
torch._check(size >= min)
|
||||
if max is not None:
|
||||
torch._check(size <= max)
|
||||
return
|
||||
_constrain_range_for_size(size, min=min, max=max)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1356,24 +1356,12 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
|||
def _compute_proxy(
|
||||
self, func: OpOverload, args: Tuple[object, ...], out: PySymType
|
||||
) -> Proxy:
|
||||
# Handle torch.sym_sum
|
||||
n_args: Tuple[object, ...]
|
||||
if len(args) == 1 and isinstance(args[0], (list, tuple)):
|
||||
n_args = (
|
||||
tuple(
|
||||
get_proxy_slot(a, self.tracer).force().node
|
||||
if isinstance(a, py_sym_types)
|
||||
else a
|
||||
for a in args[0]
|
||||
),
|
||||
)
|
||||
else:
|
||||
n_args = tuple(
|
||||
get_proxy_slot(a, self.tracer).force().node
|
||||
if isinstance(a, py_sym_types)
|
||||
else a
|
||||
for a in args
|
||||
)
|
||||
n_args = tuple(
|
||||
get_proxy_slot(a, self.tracer).force().node
|
||||
if isinstance(a, py_sym_types)
|
||||
else a
|
||||
for a in args
|
||||
)
|
||||
|
||||
# func doesn't have a __torch_function__ that Proxy can interpose, so
|
||||
# we gotta do it manually
|
||||
|
|
|
|||
|
|
@ -422,47 +422,6 @@ class SymNode:
|
|||
def int_(self):
|
||||
return self.guard_int("", 0) # NB: uses Python backtrace
|
||||
|
||||
# This one is currently done by hand, but if we add other variadic
|
||||
# functions consider factoring it out to be metaprogrammed too. Note that
|
||||
# some load bearing logic is directly in torch.sym_sum
|
||||
|
||||
def sym_sum(self, args) -> "SymNode":
    """
    N-ary addition over ``args``, returning a new integer ``SymNode``.

    Each element of ``args`` is read through its ``expr``, ``hint`` and
    ``fx_node`` attributes.  When a proxy mode is active the call is routed
    through ``handle_sym_dispatch`` instead of being computed here.
    """
    import sympy

    # Inner impl
    from torch.fx.experimental.proxy_tensor import (
        get_proxy_mode,
        handle_sym_dispatch,
    )

    # Under a proxy mode, dispatch torch.sym_sum on the wrapped arguments
    # (passed as one tuple argument) and convert the result back to a node.
    if get_proxy_mode():
        return to_node(
            self,
            handle_sym_dispatch(
                torch.sym_sum,
                (tuple(wrap_node(a) for a in args),),
                {},
            ),
        )
    # Build one n-ary sympy.Add from all the argument expressions.
    exprs = [a.expr for a in args]
    out = sympy.Add(*exprs)

    # The result has a hint only if every argument has one: the loop breaks
    # at the first missing hint, and the ``else`` runs only without a break.
    size_hints = []
    out_hint = None
    for a in args:
        if a.hint is None:
            break
        size_hints.append(a.hint)
    else:
        out_hint = sum(size_hints)

    # Create the FX call_function node for torch.sym_sum; the arguments'
    # fx nodes are passed as a single tuple argument.
    fx_node, _ = self.shape_env._create_fx_call_function(
        torch.sym_sum, (tuple(a.fx_node for a in args),)
    )

    # NB: Only for integers!
    return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node)
|
||||
|
||||
# You can manually trigger a guard with this function
|
||||
def guard_int(self, file, line):
|
||||
# TODO: use the file/line for some useful diagnostic on why a
|
||||
|
|
|
|||
|
|
@ -163,9 +163,6 @@ try:
|
|||
def to_int(x: z3.ArithRef) -> z3.ArithRef:
|
||||
return x if x.is_int() else z3.ToInt(x)
|
||||
|
||||
def sym_sum(self, args: z3.ArithRef) -> z3.ArithRef:
    # N-ary addition: folds ``args`` with the builtin ``sum``, which starts
    # from the integer 0 and applies ``+`` left to right.
    # NOTE(review): ``sum`` iterates ``args``, so it is used as a collection
    # of addends even though the annotation names a single ``z3.ArithRef``;
    # an empty collection would return the plain int 0 -- confirm and adjust
    # the annotation if intended.
    return sum(args)
|
||||
|
||||
# Implements Python division semantics.
|
||||
def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
|
||||
self.validator.add_assertion(denominator != 0) # type: ignore[arg-type]
|
||||
|
|
@ -268,10 +265,7 @@ try:
|
|||
@functools.wraps(func)
|
||||
def wrapper(*args):
|
||||
# Lifts the arguments into a list of Z3 inhabitants.
|
||||
if len(args) == 1 and isinstance(args[0], (list, tuple)):
|
||||
wrapped_args = (tuple(wrap(a) for a in args[0]),)
|
||||
else:
|
||||
wrapped_args = tuple(wrap(a) for a in args)
|
||||
wrapped_args = (wrap(a) for a in args)
|
||||
# Run the function on the Z3 expressions.
|
||||
return func(*wrapped_args)
|
||||
|
||||
|
|
@ -295,7 +289,6 @@ try:
|
|||
torch.sym_float: lift(ops.to_real),
|
||||
torch.sym_max: lift(ops.max),
|
||||
torch.sym_min: lift(ops.min),
|
||||
torch.sym_sum: lift(ops.sym_sum),
|
||||
torch.sym_ite: lift(lambda b, t, f: t if b else f),
|
||||
torch._sym_sqrt: lift(ops.sqrt), # type: ignore[attr-defined]
|
||||
# Not lifted because we only use this function as a
|
||||
|
|
|
|||
|
|
@ -105,16 +105,12 @@ def insert_deferred_runtime_asserts(
|
|||
resolve_unbacked_bindings,
|
||||
)
|
||||
from torch.utils._sympy.numbers import int_oo
|
||||
from torch.utils._sympy.reference import (
|
||||
OptimizedPythonReferenceAnalysis,
|
||||
PythonReferenceAnalysis,
|
||||
)
|
||||
from torch.utils._sympy.reference import PythonReferenceAnalysis
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
# TODO: Request simplification on runtime asserts before emitting them
|
||||
ras_by_symbol = shape_env.deferred_runtime_asserts.copy()
|
||||
graph = gm.graph
|
||||
tracer = fx.proxy.GraphAppendingTracer(graph)
|
||||
graph_code_log.debug(
|
||||
"%s",
|
||||
lazy_format_graph_code(
|
||||
|
|
@ -165,12 +161,10 @@ def insert_deferred_runtime_asserts(
|
|||
stack_trace: Optional[str] = None,
|
||||
nn_module_stack: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
fake_args = pytree.tree_map(
|
||||
lambda arg: _get_example_value(arg)
|
||||
if isinstance(arg, torch.fx.Node)
|
||||
else arg,
|
||||
node.args,
|
||||
)
|
||||
fake_args = [
|
||||
_get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg
|
||||
for arg in node.args
|
||||
]
|
||||
try:
|
||||
node.meta[val_key] = node.target(*fake_args) # type: ignore[operator]
|
||||
except NotImplementedError:
|
||||
|
|
@ -187,8 +181,6 @@ def insert_deferred_runtime_asserts(
|
|||
added_asserts: Set[sympy.Expr] = set()
|
||||
constrained_unbacked_symbols: Set[sympy.Symbol] = set()
|
||||
|
||||
Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis
|
||||
|
||||
def _sympy_interp(expr_to_proxy, expr):
|
||||
# sympy_interp() with hash consing
|
||||
from sympy import Integer, Number, Symbol
|
||||
|
|
@ -201,11 +193,11 @@ def insert_deferred_runtime_asserts(
|
|||
return expr_to_proxy[expr]
|
||||
# base cases, don't cache
|
||||
if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)):
|
||||
return sympy_interp(Analysis, expr_to_proxy, expr)
|
||||
return sympy_interp(PythonReferenceAnalysis, expr_to_proxy, expr)
|
||||
|
||||
# hash cons on arguments, run expr handler
|
||||
expr_to_proxy[expr] = _run_sympy_handler(
|
||||
Analysis,
|
||||
PythonReferenceAnalysis,
|
||||
[_sympy_interp(expr_to_proxy, arg) for arg in expr.args],
|
||||
expr,
|
||||
)
|
||||
|
|
@ -289,7 +281,7 @@ def insert_deferred_runtime_asserts(
|
|||
and s not in expr_to_proxy
|
||||
):
|
||||
with _set_node_metadata_hook(gm, _node_metadata_hook):
|
||||
expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer)
|
||||
expr_to_proxy[s] = fx.Proxy(cb())
|
||||
log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
|
||||
|
||||
match_symbol(example_value, lambda: node)
|
||||
|
|
@ -395,7 +387,7 @@ def insert_deferred_runtime_asserts(
|
|||
elif sym_expr not in expr_to_proxy and not isinstance(
|
||||
sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
|
||||
): # don't hash cons primitives
|
||||
expr_to_proxy[sym_expr] = fx.Proxy(node, tracer=tracer) # type: ignore[arg-type]
|
||||
expr_to_proxy[sym_expr] = fx.Proxy(node) # type: ignore[arg-type]
|
||||
|
||||
# We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained,
|
||||
# so calls before that are redundant.
|
||||
|
|
@ -488,9 +480,7 @@ def insert_deferred_runtime_asserts(
|
|||
|
||||
if s not in expr_to_proxy:
|
||||
with _set_node_metadata_hook(gm, _node_metadata_hook):
|
||||
expr_to_proxy[s] = fx.Proxy(
|
||||
go(node, keypath), tracer=tracer
|
||||
)
|
||||
expr_to_proxy[s] = fx.Proxy(go(node, keypath))
|
||||
log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
|
||||
|
||||
for i0 in defs:
|
||||
|
|
|
|||
|
|
@ -1139,7 +1139,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.sym_min: lambda a, b: -1,
|
||||
torch.sym_not: lambda input: -1,
|
||||
torch.sym_ite: lambda a, b, c: -1,
|
||||
torch.sym_sum: lambda args: -1,
|
||||
torch._sym_sqrt: lambda input: -1,
|
||||
torch._sym_cos: lambda input: -1,
|
||||
torch._sym_cosh: lambda input: -1,
|
||||
|
|
|
|||
|
|
@ -138,12 +138,6 @@ def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
|
|||
if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
|
||||
return getattr(analysis, handler_name)(*args, index_dtype)
|
||||
|
||||
# Fastpath for n-ary integral addition
|
||||
if expr.func is sympy.Add and expr.is_integer and hasattr(analysis, "sym_sum"):
|
||||
r = analysis.sym_sum(args)
|
||||
log.debug("sym_sum(%s) -> %s", args, r)
|
||||
return r
|
||||
|
||||
if hasattr(expr.func, "_torch_handler_name"):
|
||||
handler_name = expr.func._torch_handler_name
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -142,10 +142,6 @@ class ReferenceAnalysis:
|
|||
def add(a, b):
|
||||
return _keep_float(operator.add)(a, b)
|
||||
|
||||
@classmethod
def sym_sum(cls, args):
    """Return the sum of ``args`` as one n-ary ``sympy.Add``."""
    terms = tuple(args)
    return sympy.Add(*terms)
|
||||
|
||||
@staticmethod
|
||||
def mul(a, b):
|
||||
return _keep_float(operator.mul)(a, b)
|
||||
|
|
@ -210,17 +206,6 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
|
|||
def not_(a):
|
||||
return torch.sym_not(a)
|
||||
|
||||
@classmethod
def sym_sum(cls, args):
    """Sum ``args`` by folding them left to right with ``cls.add``.

    Returns 0 for an empty ``args`` and the sole element, untouched, when
    there is exactly one.
    """
    if len(args) == 0:
        return 0
    total, *remaining = args
    for term in remaining:
        total = cls.add(total, term)
    return total
|
||||
|
||||
@staticmethod
|
||||
def floordiv(a, b):
|
||||
return a // b
|
||||
|
|
@ -299,14 +284,6 @@ class PythonReferenceAnalysis(ReferenceAnalysis):
|
|||
return round(a, ndigits=b)
|
||||
|
||||
|
||||
# Like PythonReferenceAnalysis, but some export-unfriendly choices of
|
||||
# operators to make things faster
|
||||
class OptimizedPythonReferenceAnalysis(PythonReferenceAnalysis):
    """``PythonReferenceAnalysis`` whose n-ary sum is one ``torch.sym_sum`` call."""

    @staticmethod
    def sym_sum(args):
        """Sum ``args`` with a single call to ``torch.sym_sum``."""
        total = torch.sym_sum(args)
        return total
|
||||
|
||||
|
||||
def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch.ops.prims.convert_element_type.default(x, dtype)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue