[DataPipe] Refactor 'mux' to have buffer as an instance variable

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77775

Approved by: https://github.com/ejguan
This commit is contained in:
Kevin Tse 2022-05-19 11:26:07 -04:00 committed by PyTorch MergeBot
parent ba0ca0f591
commit b4a6730ce1
2 changed files with 30 additions and 5 deletions

View file

@ -2197,7 +2197,7 @@ class TestSerialization(TestCase):
dl = DataLoader(idp, num_workers=2, shuffle=True,
multiprocessing_context='spawn', collate_fn=unbatch, batch_size=1)
result = list(dl)
self.assertEquals([1, 1, 2, 2, 3, 3], sorted(result))
self.assertEqual([1, 1, 2, 2, 3, 3], sorted(result))
@skipIfNoDill
def test_spawn_lambdas_map(self):
@ -2205,7 +2205,7 @@ class TestSerialization(TestCase):
dl = DataLoader(mdp, num_workers=2, shuffle=True,
multiprocessing_context='spawn', collate_fn=unbatch, batch_size=1)
result = list(dl)
self.assertEquals([1, 2, 3, 4, 5, 6], sorted(result))
self.assertEqual([1, 2, 3, 4, 5, 6], sorted(result))
class TestCircularSerialization(TestCase):

View file

@ -432,19 +432,21 @@ class MultiplexerIterDataPipe(IterDataPipe):
def __init__(self, *datapipes):
self.datapipes = datapipes
self.length: Optional[int] = None
self.buffer: List = [] # Store values to be yielded only when every iterator provides one
def __iter__(self):
iterators = [iter(x) for x in self.datapipes]
while len(iterators):
values: List[Any] = []
for it in iterators:
try:
value = next(it)
values.append(value)
self.buffer.append(value)
except StopIteration:
self.buffer.clear()
return
for value in values:
for value in self.buffer:
yield value
self.buffer.clear()
def __len__(self):
if self.length is not None:
@ -457,6 +459,29 @@ class MultiplexerIterDataPipe(IterDataPipe):
self.length = -1
return len(self)
def reset(self) -> None:
self.buffer = []
def __getstate__(self):
if IterDataPipe.getstate_hook is not None:
return IterDataPipe.getstate_hook(self)
state = (
self.datapipes,
self.length,
)
return state
def __setstate__(self, state):
(
self.datapipes,
self.length,
) = state
self.buffer = []
def __del__(self):
self.buffer.clear()
@functional_datapipe('zip')
class ZipperIterDataPipe(IterDataPipe[Tuple[T_co]]):