mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Use default_collate from public API (#143616)
Codemodded via `torchfix . --select=TOR104 --fix`. This is a step to unblock https://github.com/pytorch/pytorch/pull/141076 Pull Request resolved: https://github.com/pytorch/pytorch/pull/143616 Approved by: https://github.com/malfet
This commit is contained in:
parent
a70191da41
commit
c042c8a475
1 changed files with 22 additions and 23 deletions
|
|
@ -49,6 +49,7 @@ from torch.utils.data import (
|
|||
ChainDataset,
|
||||
ConcatDataset,
|
||||
DataLoader,
|
||||
dataloader,
|
||||
Dataset,
|
||||
IterableDataset,
|
||||
IterDataPipe,
|
||||
|
|
@ -2758,51 +2759,51 @@ except RuntimeError as e:
|
|||
|
||||
def test_default_convert_mapping_keep_type(self):
|
||||
data = CustomDict({"a": 1, "b": 2})
|
||||
converted = _utils.collate.default_convert(data)
|
||||
converted = dataloader.default_convert(data)
|
||||
|
||||
self.assertEqual(converted, data)
|
||||
|
||||
def test_default_convert_sequence_keep_type(self):
|
||||
data = CustomList([1, 2, 3])
|
||||
converted = _utils.collate.default_convert(data)
|
||||
converted = dataloader.default_convert(data)
|
||||
|
||||
self.assertEqual(converted, data)
|
||||
|
||||
def test_default_convert_sequence_dont_keep_type(self):
|
||||
data = range(2)
|
||||
converted = _utils.collate.default_convert(data)
|
||||
converted = dataloader.default_convert(data)
|
||||
|
||||
self.assertEqual(converted, [0, 1])
|
||||
|
||||
def test_default_collate_dtype(self):
|
||||
arr = [1, 2, -1]
|
||||
collated = _utils.collate.default_collate(arr)
|
||||
collated = dataloader.default_collate(arr)
|
||||
self.assertEqual(collated, torch.tensor(arr))
|
||||
self.assertEqual(collated.dtype, torch.int64)
|
||||
|
||||
arr = [1.1, 2.3, -0.9]
|
||||
collated = _utils.collate.default_collate(arr)
|
||||
collated = dataloader.default_collate(arr)
|
||||
self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64))
|
||||
|
||||
arr = [True, False]
|
||||
collated = _utils.collate.default_collate(arr)
|
||||
collated = dataloader.default_collate(arr)
|
||||
self.assertEqual(collated, torch.tensor(arr))
|
||||
self.assertEqual(collated.dtype, torch.bool)
|
||||
|
||||
# Should be a no-op
|
||||
arr = ["a", "b", "c"]
|
||||
self.assertEqual(arr, _utils.collate.default_collate(arr))
|
||||
self.assertEqual(arr, dataloader.default_collate(arr))
|
||||
|
||||
def test_default_collate_mapping_keep_type(self):
|
||||
batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})]
|
||||
collated = _utils.collate.default_collate(batch)
|
||||
collated = dataloader.default_collate(batch)
|
||||
|
||||
expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])})
|
||||
self.assertEqual(collated, expected)
|
||||
|
||||
def test_default_collate_sequence_keep_type(self):
|
||||
batch = [CustomList([1, 2, 3]), CustomList([4, 5, 6])]
|
||||
collated = _utils.collate.default_collate(batch)
|
||||
collated = dataloader.default_collate(batch)
|
||||
|
||||
expected = CustomList(
|
||||
[
|
||||
|
|
@ -2815,7 +2816,7 @@ except RuntimeError as e:
|
|||
|
||||
def test_default_collate_sequence_dont_keep_type(self):
|
||||
batch = [range(2), range(2)]
|
||||
collated = _utils.collate.default_collate(batch)
|
||||
collated = dataloader.default_collate(batch)
|
||||
|
||||
self.assertEqual(collated, [torch.tensor([0, 0]), torch.tensor([1, 1])])
|
||||
|
||||
|
|
@ -2825,16 +2826,16 @@ except RuntimeError as e:
|
|||
|
||||
# Should be a no-op
|
||||
arr = np.array(["a", "b", "c"])
|
||||
self.assertEqual(arr, _utils.collate.default_collate(arr))
|
||||
self.assertEqual(arr, dataloader.default_collate(arr))
|
||||
|
||||
arr = np.array([[["a", "b", "c"]]])
|
||||
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
|
||||
self.assertRaises(TypeError, lambda: dataloader.default_collate(arr))
|
||||
|
||||
arr = np.array([object(), object(), object()])
|
||||
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
|
||||
self.assertRaises(TypeError, lambda: dataloader.default_collate(arr))
|
||||
|
||||
arr = np.array([[[object(), object(), object()]]])
|
||||
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
|
||||
self.assertRaises(TypeError, lambda: dataloader.default_collate(arr))
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
|
||||
def test_default_collate_numpy_memmap(self):
|
||||
|
|
@ -2845,7 +2846,7 @@ except RuntimeError as e:
|
|||
arr_memmap = np.memmap(f, dtype=arr.dtype, mode="w+", shape=arr.shape)
|
||||
arr_memmap[:] = arr[:]
|
||||
arr_new = np.memmap(f, dtype=arr.dtype, mode="r", shape=arr.shape)
|
||||
tensor = _utils.collate.default_collate(list(arr_new))
|
||||
tensor = dataloader.default_collate(list(arr_new))
|
||||
|
||||
self.assertTrue(
|
||||
(tensor == tensor.new_tensor([[0, 1], [2, 3], [4, 5], [6, 7]])).all().item()
|
||||
|
|
@ -2853,10 +2854,8 @@ except RuntimeError as e:
|
|||
|
||||
def test_default_collate_bad_sequence_type(self):
|
||||
batch = [["X"], ["X", "X"]]
|
||||
self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch))
|
||||
self.assertRaises(
|
||||
RuntimeError, lambda: _utils.collate.default_collate(batch[::-1])
|
||||
)
|
||||
self.assertRaises(RuntimeError, lambda: dataloader.default_collate(batch))
|
||||
self.assertRaises(RuntimeError, lambda: dataloader.default_collate(batch[::-1]))
|
||||
|
||||
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
|
||||
def test_default_collate_shared_tensor(self):
|
||||
|
|
@ -2867,8 +2866,8 @@ except RuntimeError as e:
|
|||
|
||||
self.assertEqual(t_in.is_shared(), False)
|
||||
|
||||
self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False)
|
||||
self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False)
|
||||
self.assertEqual(dataloader.default_collate([t_in]).is_shared(), False)
|
||||
self.assertEqual(dataloader.default_collate([n_in]).is_shared(), False)
|
||||
|
||||
# FIXME: fix the following hack that makes `default_collate` believe
|
||||
# that it is in a worker process (since it tests
|
||||
|
|
@ -2876,8 +2875,8 @@ except RuntimeError as e:
|
|||
old = _utils.worker._worker_info
|
||||
try:
|
||||
_utils.worker._worker_info = "x"
|
||||
self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True)
|
||||
self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True)
|
||||
self.assertEqual(dataloader.default_collate([t_in]).is_shared(), True)
|
||||
self.assertEqual(dataloader.default_collate([n_in]).is_shared(), True)
|
||||
finally:
|
||||
_utils.worker._worker_info = old
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue