mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[DataPipe] Refactor 'mux' to have buffer as an instance variable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77775 Approved by: https://github.com/ejguan
This commit is contained in:
parent
ba0ca0f591
commit
b4a6730ce1
2 changed files with 30 additions and 5 deletions
|
|
@ -2197,7 +2197,7 @@ class TestSerialization(TestCase):
|
|||
dl = DataLoader(idp, num_workers=2, shuffle=True,
|
||||
multiprocessing_context='spawn', collate_fn=unbatch, batch_size=1)
|
||||
result = list(dl)
|
||||
self.assertEquals([1, 1, 2, 2, 3, 3], sorted(result))
|
||||
self.assertEqual([1, 1, 2, 2, 3, 3], sorted(result))
|
||||
|
||||
@skipIfNoDill
|
||||
def test_spawn_lambdas_map(self):
|
||||
|
|
@ -2205,7 +2205,7 @@ class TestSerialization(TestCase):
|
|||
dl = DataLoader(mdp, num_workers=2, shuffle=True,
|
||||
multiprocessing_context='spawn', collate_fn=unbatch, batch_size=1)
|
||||
result = list(dl)
|
||||
self.assertEquals([1, 2, 3, 4, 5, 6], sorted(result))
|
||||
self.assertEqual([1, 2, 3, 4, 5, 6], sorted(result))
|
||||
|
||||
|
||||
class TestCircularSerialization(TestCase):
|
||||
|
|
|
|||
|
|
@ -432,19 +432,21 @@ class MultiplexerIterDataPipe(IterDataPipe):
|
|||
def __init__(self, *datapipes):
|
||||
self.datapipes = datapipes
|
||||
self.length: Optional[int] = None
|
||||
self.buffer: List = [] # Store values to be yielded only when every iterator provides one
|
||||
|
||||
def __iter__(self):
|
||||
iterators = [iter(x) for x in self.datapipes]
|
||||
while len(iterators):
|
||||
values: List[Any] = []
|
||||
for it in iterators:
|
||||
try:
|
||||
value = next(it)
|
||||
values.append(value)
|
||||
self.buffer.append(value)
|
||||
except StopIteration:
|
||||
self.buffer.clear()
|
||||
return
|
||||
for value in values:
|
||||
for value in self.buffer:
|
||||
yield value
|
||||
self.buffer.clear()
|
||||
|
||||
def __len__(self):
|
||||
if self.length is not None:
|
||||
|
|
@ -457,6 +459,29 @@ class MultiplexerIterDataPipe(IterDataPipe):
|
|||
self.length = -1
|
||||
return len(self)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.buffer = []
|
||||
|
||||
def __getstate__(self):
|
||||
if IterDataPipe.getstate_hook is not None:
|
||||
return IterDataPipe.getstate_hook(self)
|
||||
|
||||
state = (
|
||||
self.datapipes,
|
||||
self.length,
|
||||
)
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
(
|
||||
self.datapipes,
|
||||
self.length,
|
||||
) = state
|
||||
self.buffer = []
|
||||
|
||||
def __del__(self):
|
||||
self.buffer.clear()
|
||||
|
||||
|
||||
@functional_datapipe('zip')
|
||||
class ZipperIterDataPipe(IterDataPipe[Tuple[T_co]]):
|
||||
|
|
|
|||
Loading…
Reference in a new issue