Implement generator.__iter__() (#144421)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144421
Approved by: https://github.com/zou3519
ghstack dependencies: #141055
This commit is contained in:
Guilherme Leobas 2025-02-07 14:55:19 -03:00 committed by PyTorch MergeBot
parent 8603a1c870
commit d798831167
2 changed files with 23 additions and 7 deletions

View file

@ -279,7 +279,6 @@ class GraphModule(torch.nn.Module):
expected = list(zip(range(3), whoo(t)))
self.assertEqual(expected, list(y))
@unittest.expectedFailure
def test_zip_subgenerator(self):
def subgen(t):
yield t + 1
@ -329,9 +328,6 @@ class GraphModule(torch.nn.Module):
@parametrize("container", [list, tuple, dict, OrderedDict])
def test_dict_tuple_list_generator(self, container):
if container in (dict, OrderedDict):
self.skipTest("Needs __iter__")
def whoo(t):
yield 1, t + 1
yield 2, t + 2
@ -407,7 +403,6 @@ class GraphModule(torch.nn.Module):
with self.assertRaises(StopIteration):
next(gen)
@unittest.expectedFailure
def test_subgenerator(self):
def subgen(t):
yield t + 1
@ -426,7 +421,6 @@ class GraphModule(torch.nn.Module):
y = fn(t)
self.assertEqual(y, [t + 1, t + 2, t + 3])
@unittest.expectedFailure
def test_return_subgenerator(self):
def subgen(t):
yield t + 1
@ -598,6 +592,26 @@ class GraphModule(torch.nn.Module):
self.assertEqual(i, 3)
self.assertEqual(y, [(0, t), (1, t + 1), (2, t + 2)])
def test_iter(self):
def whoo():
i = 0
while True:
yield i
i += 1
@torch.compile(backend="eager", fullgraph=True)
def fn(t):
s = 0
for i in whoo():
if i > 5:
break
s += i
return t + s
t = torch.randn(2)
y = fn(t)
self.assertEqual(y, t + sum(range(6)))
class GeneratorCPythonTests(GeneratorTestsBase):
# Taken from commit
@ -625,7 +639,6 @@ class GeneratorCPythonTests(GeneratorTestsBase):
self._compile_check(fn)
@unittest.expectedFailure
def test_issue103488(self):
def gen_raises():
yield 1

View file

@ -499,6 +499,9 @@ class LocalGeneratorObjectVariable(VariableTracker):
) -> "VariableTracker":
if name == "__next__":
return self.next_variable(tx)
elif name == "__iter__":
# iter(gen) returns itself
return self
super().call_method(tx, name, args, kwargs)