From 01cdcbf7c83cd84a3685766f7f8cd26ad447feae Mon Sep 17 00:00:00 2001 From: William Wen Date: Sun, 4 Aug 2024 18:46:55 +0000 Subject: [PATCH] [dynamo] revert map/zip iterator related changes (#132528) Need to revert due to internal hangs: S437700 This reverts commit b6c1490cc02316ffe85e5ae74651d80f0158ba64. Revert "[dynamo] implement IteratorVariable and polyfill fallbacks for enumerate (#131725)" This reverts commit 2576dbbc35d66e8e9ed6cb12216ccc424cb87ec3. Revert "[dynamo] add itertools repeat/count bytecode reconstruction (#131716)" This reverts commit 35b4de32fafc5ad024c20ef1275711bffc557ae9. Revert "[dynamo] add lazy IteratorVariable implementations for map and zip (#131413)" This reverts commit 7d282d87550787d8269593093519c2ad7c5032cd. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/132528 Approved by: https://github.com/ZainRizvi --- test/dynamo/test_functions.py | 218 +--------------------- test/dynamo/test_repros.py | 8 +- torch/_dynamo/polyfill.py | 9 +- torch/_dynamo/symbolic_convert.py | 18 +- torch/_dynamo/utils.py | 8 +- torch/_dynamo/variables/__init__.py | 3 - torch/_dynamo/variables/base.py | 13 -- torch/_dynamo/variables/builtin.py | 139 ++++++-------- torch/_dynamo/variables/constant.py | 8 - torch/_dynamo/variables/dicts.py | 1 - torch/_dynamo/variables/iter.py | 233 +----------------------- torch/_dynamo/variables/lists.py | 33 +--- torch/_dynamo/variables/user_defined.py | 6 +- 13 files changed, 92 insertions(+), 605 deletions(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index b07edf950a2..4c7a9759e4f 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -181,22 +181,6 @@ class FunctionTests(torch._dynamo.test_case.TestCase): v = v + x return v - def test_itertools_reconstruct(self): - def fn(a): - it1 = itertools.repeat(1) - it2 = itertools.count(2) - for _ in range(3): - a += next(it1) - a += next(it2) - return it1, it2, a - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - i1, i2, a = fn(torch.ones(3, 3)) - it1, it2, b = opt_fn(torch.ones(3, 3)) - self.assertEqual(next(i1), next(it1)) - self.assertEqual(next(i2), next(it2)) - self.assertEqual(a, b) - @make_test def test_obj_eq(a, b): v = a + b @@ -449,7 +433,8 @@ class FunctionTests(torch._dynamo.test_case.TestCase): empty = collections.deque() d.extend(empty) - return d + # dynamo same() util doesn't support deque so just return a list + return list(d) @make_test def test_slice1(a): @@ -2886,199 +2871,6 @@ class GraphModule(torch.nn.Module): fn(arr, np.s_[..., 1], np.array([3, 3])), np.array([[1, 3], [2, 3]]) ) - def test_map_return(self): - def fn(a, b): - return map(lambda x: x + 1, [a, b]) - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - m = opt_fn(torch.randn(3, 3), torch.randn(3, 3)) - self.assertIsInstance(m, map) - - @make_test - def test_map_max(a, b): - return max(map(lambda x: x.sum(), [a, b])) - - # max(map(...)) graph breaks - @unittest.expectedFailure - @make_test - def test_map_max_const(a): - return max(map(lambda x: x, [1, 2, 3])), a + 1 - - @make_test - def test_map_list(a, b): - return list(map(lambda x: x + 1, [a, b])) - - @make_test - def test_map_tuple(a, b): - return tuple(map(lambda x: x + 1, [a, b])) - - @make_test - def test_map_iter(a, b): - it = iter(map(lambda x: x + 1, [a, b])) - return next(it) - - @make_test - def test_map_zip_dict(a): - d = dict( - zip( - map(lambda x: x + 1, [0, 1, 2]), - [map(lambda x: x - 1, [y]) for y in [3, 4, 5]], - ) - ) - return list(d[3])[0], a + 1 # noqa: RUF015 - - @make_test - def test_map_dict_fromkeys(a): - return dict.fromkeys(map(lambda x: x + 1, [0, 1])), a + 1 - - @make_test - def test_map_set(a): - return set(map(lambda x: x + 1, [0, 1])), a + 1 - - # test_map_sum defined earlier - - @make_test - def test_map_reduce(a, b): - return functools.reduce(lambda x, y: x + y, map(lambda x: x + 1, [a, b])) - - @make_test - def test_map_sorted(a): - return sorted(map(lambda x: x + 1, [0, 4, 3, 1, 2])), a + 1 - - @make_test - def test_map_list_extend(a, b, c): - l = [a] - l.extend(map(lambda x: x + 1, [b, c])) - return l - - @make_test - def test_map_list_slice_assign(a, b, c, d, e): - l = [a, b, c] - l[1:2] = map(lambda x: x + 1, [d, e]) - return l - - @make_test - def test_map_deque_extendleft(a, b, c): - d = collections.deque([a]) - d.extendleft(map(lambda x: x + 1, [b, c])) - return d - - @make_test - def test_map_str_join(a): - return "".join(map(lambda x: x, ["a", "b", "c"])), a + 1 - - def test_map_with_graph_break(self): - def f(a): - a += 1 - - def g(x): - nonlocal a - a += 1 - return x + 1 - - m = map(g, [1, 2, 3, 4, 5]) - a += next(m) # won't graph break - torch._dynamo.graph_break() - a += next(m) # will graph break - return a - - cnts = torch._dynamo.testing.CompileCounter() - opt_f = torch.compile(f, backend=cnts) - self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3))) - self.assertEqual(cnts.frame_count, 3) - - def test_map_reconstruct(self): - def fn(a): - return map(lambda x: x[0] + x[1], zip([1, 2, 3], [1, 2, 3])), a + 1 - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - m = opt_fn(torch.ones(3, 3))[0] - self.assertIsInstance(m, map) - self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) - - def test_zip_reconstruct(self): - def fn(a): - return zip([1, 2, 3], map(lambda x: x + 1, [1, 2, 3])), a + 1 - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - m = opt_fn(torch.ones(3, 3))[0] - self.assertIsInstance(m, zip) - self.assertEqual(list(m), list(fn(torch.ones(3, 3))[0])) - - @make_test - def test_map_partial_unpack(a, b): - y = 1 - - def f(x): - nonlocal y - y += 1 - return x - - l = list(zip([a, b], map(f, [1, 2, 3, 4]))) - return a + y - - @make_test - def test_map_call_function_ex(a, b): - def f(x, y): - return x + y - - return f(*map(lambda x: x + 1, [a, b])) - - @make_test - def test_map_unpack_twice(a, b): - m = map(lambda x: x + 1, [a, b]) - l1 = list(m) - l2 = list(m) - return l1, l2 - - @make_test - def test_enumerate(a, b): - return list(enumerate([a, b], start=1)), a + 1 - - @make_test - def test_map_enumerate(a, b): - return list(enumerate(map(lambda x: x + 1, [a, b]), start=1)), a + 1 - - @make_test - def test_map_infinite(a, b): - return list(map(lambda x, y: x + y, [a, b], itertools.count(3))) - - @make_test - def test_map_unpack_vars(a, b): - x, y = map(lambda x: x + 1, [a, b]) - return x + y - - def test_enumerate_custom(self): - class MyClass: - def __iter__(self): - self.a = 1 - return self - - def __next__(self): - if self.a > 3: - raise StopIteration - self.a += 1 - return self.a - - def fn(x): - for i, it in enumerate(MyClass()): - x += i + it - return x - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - self.assertEqual(fn(torch.ones(3, 3)), opt_fn(torch.ones(3, 3))) - - def test_enumerate_reconstruct(self): - def fn(a, b): - return enumerate([a, b], start=1) - - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - inps = (torch.randn(3, 3), torch.randn(3, 3)) - it1 = fn(*inps) - it2 = opt_fn(*inps) - self.assertIsInstance(it2, enumerate) - self.assertEqual(list(it1), list(it2)) - def udf_mul(x, y): return x * y @@ -3569,16 +3361,10 @@ class DefaultsTests(torch._dynamo.test_case.TestCase): with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): nopython_fn(x, ys[:1], zs) - with self.assertRaisesRegex(torch._dynamo.exc.UserError, "zip()"): - nopython_fn(x, ys, zs[:1]) - # Should cause fallback if allow graph break with self.assertRaisesRegex(ValueError, "zip()"): opt_fn(x, ys[:1], zs) - with self.assertRaisesRegex(ValueError, "zip()"): - opt_fn(x, ys, zs[:1]) - def test_fn_with_attr(self): def fn(x): if fn.pred: diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 9fc16db9de5..d387ed29bf3 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -5400,17 +5400,15 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): return x, y def g(x, y): - return map(f, x, y) + return tuple(map(f, x, y)) opt_g = torch.compile(g, fullgraph=True, backend="eager") inps = gen_inps(3, 3) - self.assertEqual(type(g(*inps)), type(opt_g(*inps))) - self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps))) + self.assertEqual(g(*inps), opt_g(*inps)) inps = gen_inps(3, 5) - self.assertEqual(type(g(*inps)), type(opt_g(*inps))) - self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps))) + self.assertEqual(g(*inps), opt_g(*inps)) def test_staticmethod_allow_in_graph(self): class MyClass: diff --git a/torch/_dynamo/polyfill.py b/torch/_dynamo/polyfill.py index cedfbb73885..7a8091c6282 100644 --- a/torch/_dynamo/polyfill.py +++ b/torch/_dynamo/polyfill.py @@ -24,7 +24,7 @@ def any(iterator): def index(iterator, item, start=0, end=None): - for i, elem in list(enumerate(list(iterator)))[start:end]: + for i, elem in list(enumerate(iterator))[start:end]: if item == elem: return i # This will not run in dynamo @@ -126,13 +126,6 @@ def getattr_and_trace(*args, **kwargs): return fn(*args[2:], **kwargs) -def enumerate(iterable, start=0): - n = start - for elem in iterable: - yield n, elem - n += 1 - - def mapping_get(obj, key, value=None): try: return obj.__getitem__(key) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 6cac5c17df2..93f5f30c191 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1625,8 +1625,8 @@ class InstructionTranslatorBase( if not isinstance( argsvars, BaseListVariable - ) and argsvars.has_force_unpack_var_sequence(self): - argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self)) + ) and argsvars.has_unpack_var_sequence(self): + argsvars = TupleVariable(argsvars.unpack_var_sequence(self)) # Unpack for cases like fn(**obj) where obj is a map if isinstance(kwargsvars, UserDefinedObjectVariable): @@ -1795,7 +1795,7 @@ class InstructionTranslatorBase( items = [] for seq in seqs: try: - items.extend(seq.force_unpack_var_sequence(self)) + items.extend(seq.unpack_var_sequence(self)) except NotImplementedError: unimplemented(f"BUILD_LIST_UNPACK {seq}") self.push(cls(items, mutable_local=MutableLocal())) @@ -1833,7 +1833,7 @@ class InstructionTranslatorBase( assert isinstance(keys, TupleVariable) assert keys.is_python_constant() - keys = keys.force_unpack_var_sequence(self) + keys = keys.unpack_var_sequence(self) assert len(keys) == len(values) self.push( @@ -1923,8 +1923,8 @@ class InstructionTranslatorBase( # x, y = a.shape proxy = getattr(seq.obj.as_proxy(), seq.name) val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)] - elif seq.has_force_unpack_var_sequence(self): - val = seq.force_unpack_var_sequence(self) + elif seq.has_unpack_var_sequence(self): + val = seq.unpack_var_sequence(self) else: unimplemented(f"UNPACK_SEQUENCE {seq}") if len(val) != inst.argval: @@ -1937,8 +1937,8 @@ class InstructionTranslatorBase( prefix = inst.argval & 0xFF # low byte suffix = inst.argval >> 8 # high byte seq = self.pop() - if seq.has_force_unpack_var_sequence(self): - vals = list(seq.force_unpack_var_sequence(self)) + if seq.has_unpack_var_sequence(self): + vals = list(seq.unpack_var_sequence(self)) assert len(vals) >= prefix + suffix vals_prefix = vals[:prefix] vals_list = vals[prefix : len(vals) - suffix] @@ -2362,7 +2362,7 @@ class InstructionTranslatorBase( self.UNARY_POSITIVE(inst) elif inst.argval == 6: # INTRINSIC_LIST_TO_TUPLE - self.push(TupleVariable(self.pop().force_unpack_var_sequence(self))) + self.push(TupleVariable(self.pop().unpack_var_sequence(self))) else: unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}") diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 58f363e02c1..1174862b3fe 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1387,12 +1387,8 @@ def same( """Check correctness to see if ref and res match""" if fp64_ref is None: fp64_ref = ref - if isinstance( - ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size) - ): - assert isinstance( - res, (list, tuple, collections.deque) - ), f"type mismatch {type(ref)} {type(res)}" + if isinstance(ref, (list, tuple, torch.nn.ParameterList, torch.Size)): + assert isinstance(res, (list, tuple)), f"type mismatch {type(ref)} {type(res)}" if len(ref) != len(res): log_error("Length mismatch") return False diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index e880d705fca..ffce6a638bd 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -41,12 +41,9 @@ from .higher_order_ops import ( from .iter import ( CountIteratorVariable, CycleIteratorVariable, - EnumerateVariable, IteratorVariable, ItertoolsVariable, - MapVariable, RepeatIteratorVariable, - ZipVariable, ) from .lazy import LazyVariableTracker from .lists import ( diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 8064a1a462e..e7a3b7320e5 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -286,15 +286,6 @@ class VariableTracker(metaclass=VariableTrackerMeta): def unpack_var_sequence(self, tx) -> List["VariableTracker"]: raise NotImplementedError - def force_unpack_var_sequence(self, tx) -> List["VariableTracker"]: - # like unpack_var_sequence, but should only be used when it is - # safe to eagerly (vs. lazily) unpack this variable. - # e.g. map(f, x) is normally evaluated lazily but sometimes - # we want to force eager unpacking, e.g. when converting to a list. - # NOTE: this method is allowed to mutate the VariableTracker, so - # it should only be called once. - return self.unpack_var_sequence(tx) - def has_unpack_var_sequence(self, tx) -> bool: try: self.unpack_var_sequence(tx) @@ -302,10 +293,6 @@ class VariableTracker(metaclass=VariableTrackerMeta): except NotImplementedError: return False - # NB: don't call force_unpack_var_sequence, especially if it mutates! - def has_force_unpack_var_sequence(self, tx) -> bool: - return self.has_unpack_var_sequence(tx) - def inspect_parameter_names(self) -> List[str]: unimplemented(f"inspect_parameter_names: {self}") diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index d7c07467e6d..f03c8b5028e 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1076,8 +1076,9 @@ class BuiltinVariable(VariableTracker): return tx.inline_user_function_return(user_func_variable, [arg], {}) def _call_min_max(self, tx: "InstructionTranslator", *args): - if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): - items = args[0].force_unpack_var_sequence(tx) + if len(args) == 1 and args[0].has_unpack_var_sequence(tx): + # expand iterable + items = args[0].unpack_var_sequence(tx) return self._call_min_max_seq(tx, items) elif len(args) == 2: return self._call_min_max_binary(tx, args[0], args[1]) @@ -1092,10 +1093,6 @@ class BuiltinVariable(VariableTracker): return functools.reduce(functools.partial(self._call_min_max_binary, tx), items) def _call_min_max_binary(self, tx: "InstructionTranslator", a, b): - if a is None or b is None: - # a or b could be none if we reduce and _call_min_max_binary failed - # to return something - return if self.tensor_args(a, b): if not isinstance(a, variables.TensorVariable): a, b = b, a @@ -1244,15 +1241,17 @@ class BuiltinVariable(VariableTracker): ), ) - # NOTE must handle IteratorVariable separately! def _call_iter_tuple_list( self, tx: "InstructionTranslator", obj=None, *args, **kwargs ): - assert not isinstance(obj, variables.IteratorVariable) - if self._dynamic_args(*args, **kwargs): return self._dyn_proxy(tx, *args, **kwargs) + if isinstance(obj, variables.IteratorVariable): + # For non-list iterators, we will guard on vars that + # determine the control flow + return obj + cls = variables.BaseListVariable.cls_for(self.fn) if obj is None: return cls( @@ -1280,22 +1279,9 @@ class BuiltinVariable(VariableTracker): mutable_local=MutableLocal(), ) - def _call_tuple_list(self, tx, obj=None, *args, **kwargs): - if isinstance(obj, variables.IteratorVariable): - cls = variables.BaseListVariable.cls_for(self.fn) - return cls( - list(obj.force_unpack_var_sequence(tx)), - mutable_local=MutableLocal(), - ) - else: - return self._call_iter_tuple_list(tx, obj, *args, **kwargs) - def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): - if isinstance(obj, variables.IteratorVariable): - ret = obj - else: - # Handle the case where we are iterating over a tuple, list or iterator - ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) + # Handle the case where we are iterating over a tuple, list or iterator + ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) if ret is None: # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. @@ -1304,8 +1290,8 @@ class BuiltinVariable(VariableTracker): return obj.call_method(tx, "__iter__", args, kwargs) return ret - call_tuple = _call_tuple_list - call_list = _call_tuple_list + call_tuple = _call_iter_tuple_list + call_list = _call_iter_tuple_list def call_callable(self, tx: "InstructionTranslator", arg): from .functions import BaseUserFunctionVariable @@ -1363,12 +1349,10 @@ class BuiltinVariable(VariableTracker): ListVariable, TupleVariable, ListIteratorVariable, - variables.IteratorVariable, ), ): items = dict( - x.force_unpack_var_sequence(tx) - for x in arg.force_unpack_var_sequence(tx) + x.unpack_var_sequence(tx) for x in arg.unpack_var_sequence(tx) ) return ConstDictVariable(items, user_cls, mutable_local=MutableLocal()) elif isinstance(arg, variables.MutableMappingVariable): @@ -1425,12 +1409,13 @@ class BuiltinVariable(VariableTracker): return DictVariableType( dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal() ) - elif arg.has_force_unpack_var_sequence(tx): - keys = arg.force_unpack_var_sequence(tx) - if all(is_hashable(v) for v in keys): - return DictVariableType( - dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() - ) + elif arg.has_unpack_var_sequence(tx) and all( + is_hashable(v) for v in arg.unpack_var_sequence(tx) + ): + keys = arg.unpack_var_sequence(tx) + return DictVariableType( + dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() + ) unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}") def call_set(self, tx: "InstructionTranslator", *args, **kwargs): @@ -1442,8 +1427,8 @@ class BuiltinVariable(VariableTracker): arg = args[0] if isinstance(arg, variables.SetVariable): return arg.clone(mutable_local=MutableLocal()) - elif arg.has_force_unpack_var_sequence(tx): - items = arg.force_unpack_var_sequence(tx) + elif arg.has_unpack_var_sequence(tx): + items = arg.unpack_var_sequence(tx) return SetVariable(items, mutable_local=MutableLocal()) elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( arg.value, KeysView @@ -1462,36 +1447,32 @@ class BuiltinVariable(VariableTracker): def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): if kwargs: assert len(kwargs) == 1 and "strict" in kwargs - strict = kwargs.pop("strict", False) - args = [ - arg.unpack_var_sequence(tx) if arg.has_unpack_var_sequence(tx) else arg - for arg in args - ] - return variables.ZipVariable(args, strict=strict, mutable_local=MutableLocal()) + if all(x.has_unpack_var_sequence(tx) for x in args): + unpacked = [arg.unpack_var_sequence(tx) for arg in args] + if kwargs.pop("strict", False) and len(unpacked) > 0: + if not all(len(u) == len(unpacked[0]) for u in unpacked): + raise UserError( + ValueError, + "zip() has one argument of len differing from others", + ) + items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)] + return variables.TupleVariable(items) - _call_enumerate_polyfill = _polyfill_call_impl("enumerate") - - def call_enumerate(self, tx: "InstructionTranslator", iterable, start=_SENTINEL): - if start is self._SENTINEL: + def call_enumerate(self, tx: "InstructionTranslator", *args): + if len(args) == 1: start = 0 else: - assert isinstance(start, variables.ConstantVariable) - start = start.as_python_constant() - - if iterable.has_unpack_var_sequence(tx): - return variables.EnumerateVariable( - iterable.unpack_var_sequence(tx), - start, - mutable_local=MutableLocal(), - ) - elif isinstance(iterable, variables.IteratorVariable): - return variables.EnumerateVariable( - iterable, start, mutable_local=MutableLocal() - ) - - return self._call_enumerate_polyfill( - tx, iterable, variables.ConstantVariable.create(start) - ) + assert len(args) == 2 + assert isinstance(args[1], variables.ConstantVariable) + start = args[1].as_python_constant() + if args[0].has_unpack_var_sequence(tx): + items = [ + variables.TupleVariable( + [variables.ConstantVariable.create(idx), var], + ) + for idx, var in enumerate(args[0].unpack_var_sequence(tx), start) + ] + return variables.TupleVariable(items) def call_len(self, tx: "InstructionTranslator", *args, **kwargs): return args[0].call_method(tx, "__len__", args[1:], kwargs) @@ -1592,11 +1573,10 @@ class BuiltinVariable(VariableTracker): return obj.call_hasattr(tx, name) def call_map(self, tx: "InstructionTranslator", fn, *seqs): - seqs = [ - seq.unpack_var_sequence(tx) if seq.has_unpack_var_sequence(tx) else seq - for seq in seqs - ] - return variables.MapVariable(fn, seqs, mutable_local=MutableLocal()) + if all(seq.has_unpack_var_sequence(tx) for seq in seqs): + unpacked = [seq.unpack_var_sequence(tx) for seq in seqs] + items = [fn.call_function(tx, list(args), {}) for args in zip(*unpacked)] + return variables.TupleVariable(items) def call_sum(self, tx: "InstructionTranslator", seq, start=_SENTINEL): # Special case for sum on tuple of floats and ints @@ -1615,10 +1595,10 @@ class BuiltinVariable(VariableTracker): return variables.ConstantVariable.create( sum((x.value for x in seq.items), start=start.value), ) - if seq.has_force_unpack_var_sequence(tx): + if seq.has_unpack_var_sequence(tx): if start is self._SENTINEL: start = variables.ConstantVariable.create(0) - items = seq.force_unpack_var_sequence(tx) + items = seq.unpack_var_sequence(tx) return BuiltinVariable(functools.reduce).call_function( tx, [ @@ -1632,8 +1612,8 @@ class BuiltinVariable(VariableTracker): def call_reduce( self, tx: "InstructionTranslator", function, iterable, initial=_SENTINEL ): - if iterable.has_force_unpack_var_sequence(tx): - items = iterable.force_unpack_var_sequence(tx) + if iterable.has_unpack_var_sequence(tx): + items = iterable.unpack_var_sequence(tx) if initial is self._SENTINEL: value, items = items[0], items[1:] else: @@ -1920,12 +1900,11 @@ class BuiltinVariable(VariableTracker): return variables.TupleVariable(items) def call_sorted(self, tx: "InstructionTranslator", obj: VariableTracker, **kwargs): - if obj.has_force_unpack_var_sequence(tx) and not isinstance( - obj, variables.TensorVariable + if ( + obj.has_unpack_var_sequence(tx) + and not isinstance(obj, variables.TensorVariable) + and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx)) ): - unpacked = obj.force_unpack_var_sequence(tx) - if not all(x.is_python_constant() for x in unpacked): - return function = kwargs.pop("key", None) reverse = kwargs.pop( "reverse", ConstantVariable.create(False) @@ -1933,7 +1912,7 @@ class BuiltinVariable(VariableTracker): assert len(kwargs) == 0 if function: items = sorted( - unpacked, + obj.unpack_var_sequence(tx), key=lambda x: function.call_function( tx, [x], {} ).as_python_constant(), @@ -1941,7 +1920,7 @@ class BuiltinVariable(VariableTracker): ) else: items = sorted( - unpacked, + obj.unpack_var_sequence(tx), key=lambda x: x.as_python_constant(), reverse=reverse, ) diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 28cdef19b96..f1803e57401 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -147,14 +147,6 @@ class ConstantVariable(VariableTracker): return variables.BuiltinVariable(str.format).call_function( tx, [self, *args], kwargs ) - elif name == "join" and istype(self.value, str): - assert len(args) == 1 and len(kwargs) == 0 - arg_unpacked = args[0].force_unpack_var_sequence(tx) - try: - arg_const = [x.as_python_constant() for x in arg_unpacked] - return ConstantVariable.create(self.value.join(arg_const)) - except NotImplementedError: - return super().call_method(tx, name, args, kwargs) if any(isinstance(x, SymNodeVariable) for x in args): # Promote to SymNodeVariable for operations involving dynamic shapes. diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 1907e32b190..1013ab71f2d 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -314,7 +314,6 @@ class ConstDictVariable(VariableTracker): ListVariable, TupleVariable, ListIteratorVariable, - variables.IteratorVariable, UserDefinedObjectVariable, ), ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 9c812b0ae00..f02052be70f 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -2,17 +2,14 @@ import itertools import operator -import sys -from typing import Dict, List, Optional, TYPE_CHECKING, Union +from typing import Dict, List, Optional, TYPE_CHECKING from .. import polyfill, variables -from ..bytecode_transformation import create_call_function, create_instruction from ..exc import ( handle_observed_exception, ObservedUserStopIteration, raise_observed_exception, unimplemented, - UserError, ) from .base import MutableLocal, VariableTracker from .constant import ConstantVariable @@ -60,7 +57,6 @@ class ItertoolsVariable(VariableTracker): and not kwargs and all(arg.has_unpack_var_sequence(tx) for arg in args) ): - # TODO support itertools.chain with arbitrary iterables seqs = [arg.unpack_var_sequence(tx) for arg in args] items = list(itertools.chain.from_iterable(seqs)) return variables.ListIteratorVariable(items, mutable_local=MutableLocal()) @@ -212,25 +208,6 @@ class IteratorVariable(VariableTracker): def next_variable(self, tx): unimplemented("abstract method, must implement") - # NOTE: only call when unpacking this iterator safely done eagerly! - # Normally, iterators are accessed lazily. - # Example of safe eager unpacking: list(map(f, seq)) - # Example of unsafe eager unpacking: list(islice(map(f, seq), 5)) - def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: - result = [] - while True: - try: - result.append(self.next_variable(tx)) - except ObservedUserStopIteration: - handle_observed_exception(tx) - break - return result - - # don't call force_unpack_var_sequence since it can mutate - # IteratorVariable state! - def has_force_unpack_var_sequence(self, tx) -> bool: - return True - class RepeatIteratorVariable(IteratorVariable): def __init__(self, item: VariableTracker, **kwargs) -> None: @@ -241,18 +218,6 @@ class RepeatIteratorVariable(IteratorVariable): def next_variable(self, tx): return self.item - def reconstruct(self, codegen): - codegen.add_push_null( - lambda: codegen.extend_output( - [ - codegen.create_load_python_module(itertools), - codegen.create_load_attr("repeat"), - ] - ) - ) - codegen(self.item) - codegen.extend_output(create_call_function(1, False)) - class CountIteratorVariable(IteratorVariable): def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None: @@ -266,23 +231,10 @@ class CountIteratorVariable(IteratorVariable): def next_variable(self, tx): assert self.mutable_local - old_item = self.item tx.output.side_effects.mutation(self) - self.item = self.item.call_method(tx, "__add__", [self.step], {}) - return old_item - - def reconstruct(self, codegen): - codegen.add_push_null( - lambda: codegen.extend_output( - [ - codegen.create_load_python_module(itertools), - codegen.create_load_attr("count"), - ] - ) - ) - codegen(self.item) - codegen(self.step) - codegen.extend_output(create_call_function(2, False)) + next_item = self.item.call_method(tx, "__add__", [self.step], {}) + self.item = next_item + return self.item class CycleIteratorVariable(IteratorVariable): @@ -328,180 +280,3 @@ class CycleIteratorVariable(IteratorVariable): return self.item else: raise_observed_exception(StopIteration, tx, self) - - -class ZipVariable(IteratorVariable): - """ - Represents zip(*iterables) - """ - - _nonvar_fields = { - "index", - "strict", - *IteratorVariable._nonvar_fields, - } - - def __init__( - self, - iterables: List[Union[List[VariableTracker], VariableTracker]], - strict: bool = False, - **kwargs, - ) -> None: - super().__init__(**kwargs) - assert isinstance(iterables, list) - # can be list[Variable] or VariableTracker (with next_variable implemented) - self.iterables = iterables - self.index = 0 - self.strict = strict - - def python_type(self): - return zip - - def has_unpack_var_sequence(self, tx) -> bool: - return all( - isinstance(it, list) or it.has_unpack_var_sequence(tx) - for it in self.iterables - ) - - def unpack_var_sequence(self, tx) -> List["VariableTracker"]: - assert self.has_unpack_var_sequence(tx) - iterables = [] - for it in self.iterables: - if isinstance(it, list): - iterables.append(it[self.index :]) - else: - iterables.append(it.unpack_var_sequence(tx)) - kwargs = {"strict": self.strict} if self.strict else {} - zipped = zip(*iterables, **kwargs) - return [variables.TupleVariable(list(var)) for var in zipped] - - def next_variable(self, tx): - assert self.mutable_local - old_index = self.index - args = [] - - def get_item(it): - if isinstance(it, list): - if old_index >= len(it): - raise_observed_exception(StopIteration, tx, self) - return it[old_index] - else: - return it.next_variable(tx) - - try: - for idx, it in enumerate(self.iterables): - args.append(get_item(it)) - except ObservedUserStopIteration: - if self.strict: - if idx == 0: - # all other iterables should be exhausted - for it in self.iterables: - try: - get_item(it) - except ObservedUserStopIteration: - handle_observed_exception(tx) - continue - # no ObservedUserStopIteration - fall through to UserError - break - else: - # all iterables exhausted, raise original error - raise - handle_observed_exception(tx) - raise UserError( - ValueError, - "zip() has one argument of len differing from others", - ) from None - raise - - tx.output.side_effects.mutation(self) - self.index += 1 - return variables.TupleVariable(args) - - def reconstruct_items(self, codegen): - for it in self.iterables: - if isinstance(it, list): - remaining_items = it[self.index :] - codegen.foreach(remaining_items) - codegen.append_output( - create_instruction("BUILD_TUPLE", arg=len(remaining_items)) - ) - else: - codegen(it) - - def reconstruct(self, codegen): - codegen.add_push_null( - lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True - ) - self.reconstruct_items(codegen) - codegen.append_output( - create_instruction("BUILD_TUPLE", arg=len(self.iterables)) - ) - if sys.version_info >= (3, 10): - codegen.extend_output( - [ - codegen.create_load_const("strict"), - codegen.create_load_const(self.strict), - create_instruction("BUILD_MAP", arg=1), - create_instruction("CALL_FUNCTION_EX", arg=1), - ] - ) - else: - codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0)) - - -class MapVariable(ZipVariable): - """ - Represents map(fn, *iterables) - """ - - def __init__( - self, - fn: VariableTracker, - iterables: List[Union[List[VariableTracker], VariableTracker]], - **kwargs, - ) -> None: - super().__init__(iterables, **kwargs) - self.fn = fn - - def python_type(self): - return map - - def has_unpack_var_sequence(self, tx) -> bool: - return False - - def next_variable(self, tx): - args = super().next_variable(tx) - return self.fn.call_function(tx, args.items, {}) - - def reconstruct(self, codegen): - codegen.add_push_null( - lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True - ) - codegen(self.fn) - self.reconstruct_items(codegen) - codegen.extend_output( - [ - create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1), - create_instruction("CALL_FUNCTION_EX", arg=0), - ] - ) - - -class EnumerateVariable(ZipVariable): - def __init__( - self, - iterable: Union[List[VariableTracker], VariableTracker], - start: int = 0, - **kwargs, - ) -> None: - super().__init__( - [CountIteratorVariable(start, mutable_local=MutableLocal()), iterable], - **kwargs, - ) - - def reconstruct(self, codegen): - codegen.add_push_null(lambda: codegen.load_import_from("builtins", "enumerate")) - codegen(self.iterables[1]) - assert isinstance(self.iterables[0], CountIteratorVariable) - codegen(self.iterables[0].item) - codegen.extend_output(codegen.create_call_function_kw(2, ("start",), False)) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 66237b84ad4..3a4b96f0df9 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -29,7 +29,6 @@ from ..utils import ( from .base import MutableLocal, VariableTracker from .constant import ConstantVariable from .functions import UserFunctionVariable, UserMethodVariable -from .iter import IteratorVariable if TYPE_CHECKING: @@ -340,11 +339,11 @@ class CommonListMethodsVariable(BaseListVariable): name == "extend" and self.mutable_local and args - and args[0].has_force_unpack_var_sequence(tx) + and args[0].has_unpack_var_sequence(tx) ): assert not kwargs (arg,) = args - seq = arg.force_unpack_var_sequence(tx) + seq = arg.unpack_var_sequence(tx) tx.output.side_effects.mutation(self) self.items.extend(seq) return ConstantVariable.create(None) @@ -428,13 +427,11 @@ class ListVariable(CommonListMethodsVariable): key, value = args tx.output.side_effects.mutation(self) if isinstance(key, SliceVariable): - if not value.has_force_unpack_var_sequence(tx): + if not value.has_unpack_var_sequence(tx): unimplemented( f"Missing dynamo support for expanding {value} into a list for slice assignment." ) - self.items[key.as_python_constant()] = value.force_unpack_var_sequence( - tx - ) + self.items[key.as_python_constant()] = value.unpack_var_sequence(tx) else: self.items[key.as_python_constant()] = value return ConstantVariable.create(None) @@ -462,12 +459,7 @@ class DequeVariable(CommonListMethodsVariable): ) ) codegen.foreach(self.items) - codegen.extend_output( - [ - create_instruction("BUILD_LIST", arg=len(self.items)), - *create_call_function(1, False), - ] - ) + codegen.extend_output(create_call_function(len(self.items), False)) def call_method( self, @@ -490,15 +482,11 @@ class DequeVariable(CommonListMethodsVariable): tx.output.side_effects.mutation(self) self.items[key.as_python_constant()] = value return ConstantVariable.create(None) - elif ( - name == "extendleft" - and self.mutable_local - and args[0].has_force_unpack_var_sequence(tx) - ): + elif name == "extendleft" and self.mutable_local: assert not kwargs (arg,) = args - prefix = arg.force_unpack_var_sequence(tx) + prefix = arg.unpack_var_sequence(tx) prefix.reverse() tx.output.side_effects.mutation(self) self.items = prefix + list(self.items) @@ -796,10 +784,10 @@ class SliceVariable(BaseListVariable): return self.items[fields.index(name)] -class ListIteratorVariable(IteratorVariable): +class ListIteratorVariable(VariableTracker): _nonvar_fields = { "index", - *IteratorVariable._nonvar_fields, + *VariableTracker._nonvar_fields, } def __init__(self, items, index: int = 0, **kwargs) -> None: @@ -850,9 +838,6 @@ class ListIteratorVariable(IteratorVariable): def unpack_var_sequence(self, tx): return list(self.items[self.index :]) - def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: - return self.unpack_var_sequence(tx) - def reconstruct(self, codegen): remaining_items = self.items[self.index :] codegen.foreach(remaining_items) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 1b43ff83579..cb2e131edb3 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -353,8 +353,8 @@ class UserDefinedClassVariable(UserDefinedVariable): elif self.value is collections.deque and not kwargs: if len(args) == 0: items = [] - elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): - items = args[0].force_unpack_var_sequence(tx) + elif len(args) == 1 and args[0].has_unpack_var_sequence(tx): + items = args[0].unpack_var_sequence(tx) else: unimplemented("deque() with more than 1 arg not supported") return variables.lists.DequeVariable(items, mutable_local=MutableLocal()) @@ -654,7 +654,7 @@ class UserDefinedObjectVariable(UserDefinedVariable): assert not (args or kwargs) items = [] keys = self.call_method(tx, "keys", [], {}) - for key in keys.force_unpack_var_sequence(tx): + for key in keys.unpack_var_sequence(tx): items.append( TupleVariable( [key, self.odict_getitem(tx, key)],