mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Implement generator.__iter__() (#144421)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144421 Approved by: https://github.com/zou3519 ghstack dependencies: #141055
This commit is contained in:
parent
8603a1c870
commit
d798831167
2 changed files with 23 additions and 7 deletions
|
|
@ -279,7 +279,6 @@ class GraphModule(torch.nn.Module):
|
|||
expected = list(zip(range(3), whoo(t)))
|
||||
self.assertEqual(expected, list(y))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_zip_subgenerator(self):
|
||||
def subgen(t):
|
||||
yield t + 1
|
||||
|
|
@ -329,9 +328,6 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
@parametrize("container", [list, tuple, dict, OrderedDict])
|
||||
def test_dict_tuple_list_generator(self, container):
|
||||
if container in (dict, OrderedDict):
|
||||
self.skipTest("Needs __iter__")
|
||||
|
||||
def whoo(t):
|
||||
yield 1, t + 1
|
||||
yield 2, t + 2
|
||||
|
|
@ -407,7 +403,6 @@ class GraphModule(torch.nn.Module):
|
|||
with self.assertRaises(StopIteration):
|
||||
next(gen)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_subgenerator(self):
|
||||
def subgen(t):
|
||||
yield t + 1
|
||||
|
|
@ -426,7 +421,6 @@ class GraphModule(torch.nn.Module):
|
|||
y = fn(t)
|
||||
self.assertEqual(y, [t + 1, t + 2, t + 3])
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_return_subgenerator(self):
|
||||
def subgen(t):
|
||||
yield t + 1
|
||||
|
|
@ -598,6 +592,26 @@ class GraphModule(torch.nn.Module):
|
|||
self.assertEqual(i, 3)
|
||||
self.assertEqual(y, [(0, t), (1, t + 1), (2, t + 2)])
|
||||
|
||||
def test_iter(self):
|
||||
def whoo():
|
||||
i = 0
|
||||
while True:
|
||||
yield i
|
||||
i += 1
|
||||
|
||||
@torch.compile(backend="eager", fullgraph=True)
|
||||
def fn(t):
|
||||
s = 0
|
||||
for i in whoo():
|
||||
if i > 5:
|
||||
break
|
||||
s += i
|
||||
return t + s
|
||||
|
||||
t = torch.randn(2)
|
||||
y = fn(t)
|
||||
self.assertEqual(y, t + sum(range(6)))
|
||||
|
||||
|
||||
class GeneratorCPythonTests(GeneratorTestsBase):
|
||||
# Taken from commit
|
||||
|
|
@ -625,7 +639,6 @@ class GeneratorCPythonTests(GeneratorTestsBase):
|
|||
|
||||
self._compile_check(fn)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_issue103488(self):
|
||||
def gen_raises():
|
||||
yield 1
|
||||
|
|
|
|||
|
|
@ -499,6 +499,9 @@ class LocalGeneratorObjectVariable(VariableTracker):
|
|||
) -> "VariableTracker":
|
||||
if name == "__next__":
|
||||
return self.next_variable(tx)
|
||||
elif name == "__iter__":
|
||||
# iter(gen) returns itself
|
||||
return self
|
||||
|
||||
super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue