diff --git a/test/test_datapipe.py b/test/test_datapipe.py index ab4c979a591..9b7f22eba74 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -2197,7 +2197,7 @@ class TestSerialization(TestCase): dl = DataLoader(idp, num_workers=2, shuffle=True, multiprocessing_context='spawn', collate_fn=unbatch, batch_size=1) result = list(dl) - self.assertEquals([1, 1, 2, 2, 3, 3], sorted(result)) + self.assertEqual([1, 1, 2, 2, 3, 3], sorted(result)) @skipIfNoDill def test_spawn_lambdas_map(self): @@ -2205,7 +2205,7 @@ class TestSerialization(TestCase): dl = DataLoader(mdp, num_workers=2, shuffle=True, multiprocessing_context='spawn', collate_fn=unbatch, batch_size=1) result = list(dl) - self.assertEquals([1, 2, 3, 4, 5, 6], sorted(result)) + self.assertEqual([1, 2, 3, 4, 5, 6], sorted(result)) class TestCircularSerialization(TestCase): diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 77129b43faf..cbc6b35be34 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -432,19 +432,21 @@ class MultiplexerIterDataPipe(IterDataPipe): def __init__(self, *datapipes): self.datapipes = datapipes self.length: Optional[int] = None + self.buffer: List = [] # Store values to be yielded only when every iterator provides one def __iter__(self): iterators = [iter(x) for x in self.datapipes] while len(iterators): - values: List[Any] = [] for it in iterators: try: value = next(it) - values.append(value) + self.buffer.append(value) except StopIteration: + self.buffer.clear() return - for value in values: + for value in self.buffer: yield value + self.buffer.clear() def __len__(self): if self.length is not None: @@ -457,6 +459,29 @@ class MultiplexerIterDataPipe(IterDataPipe): self.length = -1 return len(self) + def reset(self) -> None: + self.buffer = [] + + def __getstate__(self): + if IterDataPipe.getstate_hook is not None: + return IterDataPipe.getstate_hook(self) + + state = ( + self.datapipes, + self.length, + ) + return state + + def __setstate__(self, state): + ( + self.datapipes, + self.length, + ) = state + self.buffer = [] + + def __del__(self): + self.buffer.clear() + @functional_datapipe('zip') class ZipperIterDataPipe(IterDataPipe[Tuple[T_co]]):