Use default_collate and default_convert from public API (#143616)

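Codemodded via `torchfix . --select=TOR104 --fix`, which replaces calls to `_utils.collate.default_collate` / `_utils.collate.default_convert` with `dataloader.default_collate` / `dataloader.default_convert`.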
This is a step to unblock https://github.com/pytorch/pytorch/pull/141076
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143616
Approved by: https://github.com/malfet
This commit is contained in:
Sergii Dymchenko 2024-12-23 17:38:43 +00:00 committed by PyTorch MergeBot
parent a70191da41
commit c042c8a475

View file

@ -49,6 +49,7 @@ from torch.utils.data import (
ChainDataset,
ConcatDataset,
DataLoader,
dataloader,
Dataset,
IterableDataset,
IterDataPipe,
@ -2758,51 +2759,51 @@ except RuntimeError as e:
def test_default_convert_mapping_keep_type(self):
data = CustomDict({"a": 1, "b": 2})
converted = _utils.collate.default_convert(data)
converted = dataloader.default_convert(data)
self.assertEqual(converted, data)
def test_default_convert_sequence_keep_type(self):
data = CustomList([1, 2, 3])
converted = _utils.collate.default_convert(data)
converted = dataloader.default_convert(data)
self.assertEqual(converted, data)
def test_default_convert_sequence_dont_keep_type(self):
data = range(2)
converted = _utils.collate.default_convert(data)
converted = dataloader.default_convert(data)
self.assertEqual(converted, [0, 1])
def test_default_collate_dtype(self):
arr = [1, 2, -1]
collated = _utils.collate.default_collate(arr)
collated = dataloader.default_collate(arr)
self.assertEqual(collated, torch.tensor(arr))
self.assertEqual(collated.dtype, torch.int64)
arr = [1.1, 2.3, -0.9]
collated = _utils.collate.default_collate(arr)
collated = dataloader.default_collate(arr)
self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64))
arr = [True, False]
collated = _utils.collate.default_collate(arr)
collated = dataloader.default_collate(arr)
self.assertEqual(collated, torch.tensor(arr))
self.assertEqual(collated.dtype, torch.bool)
# Should be a no-op
arr = ["a", "b", "c"]
self.assertEqual(arr, _utils.collate.default_collate(arr))
self.assertEqual(arr, dataloader.default_collate(arr))
def test_default_collate_mapping_keep_type(self):
batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})]
collated = _utils.collate.default_collate(batch)
collated = dataloader.default_collate(batch)
expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])})
self.assertEqual(collated, expected)
def test_default_collate_sequence_keep_type(self):
batch = [CustomList([1, 2, 3]), CustomList([4, 5, 6])]
collated = _utils.collate.default_collate(batch)
collated = dataloader.default_collate(batch)
expected = CustomList(
[
@ -2815,7 +2816,7 @@ except RuntimeError as e:
def test_default_collate_sequence_dont_keep_type(self):
batch = [range(2), range(2)]
collated = _utils.collate.default_collate(batch)
collated = dataloader.default_collate(batch)
self.assertEqual(collated, [torch.tensor([0, 0]), torch.tensor([1, 1])])
@ -2825,16 +2826,16 @@ except RuntimeError as e:
# Should be a no-op
arr = np.array(["a", "b", "c"])
self.assertEqual(arr, _utils.collate.default_collate(arr))
self.assertEqual(arr, dataloader.default_collate(arr))
arr = np.array([[["a", "b", "c"]]])
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
self.assertRaises(TypeError, lambda: dataloader.default_collate(arr))
arr = np.array([object(), object(), object()])
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
self.assertRaises(TypeError, lambda: dataloader.default_collate(arr))
arr = np.array([[[object(), object(), object()]]])
self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr))
self.assertRaises(TypeError, lambda: dataloader.default_collate(arr))
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
def test_default_collate_numpy_memmap(self):
@ -2845,7 +2846,7 @@ except RuntimeError as e:
arr_memmap = np.memmap(f, dtype=arr.dtype, mode="w+", shape=arr.shape)
arr_memmap[:] = arr[:]
arr_new = np.memmap(f, dtype=arr.dtype, mode="r", shape=arr.shape)
tensor = _utils.collate.default_collate(list(arr_new))
tensor = dataloader.default_collate(list(arr_new))
self.assertTrue(
(tensor == tensor.new_tensor([[0, 1], [2, 3], [4, 5], [6, 7]])).all().item()
@ -2853,10 +2854,8 @@ except RuntimeError as e:
def test_default_collate_bad_sequence_type(self):
batch = [["X"], ["X", "X"]]
self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch))
self.assertRaises(
RuntimeError, lambda: _utils.collate.default_collate(batch[::-1])
)
self.assertRaises(RuntimeError, lambda: dataloader.default_collate(batch))
self.assertRaises(RuntimeError, lambda: dataloader.default_collate(batch[::-1]))
@unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
def test_default_collate_shared_tensor(self):
@ -2867,8 +2866,8 @@ except RuntimeError as e:
self.assertEqual(t_in.is_shared(), False)
self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False)
self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False)
self.assertEqual(dataloader.default_collate([t_in]).is_shared(), False)
self.assertEqual(dataloader.default_collate([n_in]).is_shared(), False)
# FIXME: fix the following hack that makes `default_collate` believe
# that it is in a worker process (since it tests
@ -2876,8 +2875,8 @@ except RuntimeError as e:
old = _utils.worker._worker_info
try:
_utils.worker._worker_info = "x"
self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True)
self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True)
self.assertEqual(dataloader.default_collate([t_in]).is_shared(), True)
self.assertEqual(dataloader.default_collate([n_in]).is_shared(), True)
finally:
_utils.worker._worker_info = old