diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 46ee9ef6fa8..cf02325c14a 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -49,6 +49,7 @@ from torch.utils.data import ( ChainDataset, ConcatDataset, DataLoader, + dataloader, Dataset, IterableDataset, IterDataPipe, @@ -2758,51 +2759,51 @@ except RuntimeError as e: def test_default_convert_mapping_keep_type(self): data = CustomDict({"a": 1, "b": 2}) - converted = _utils.collate.default_convert(data) + converted = dataloader.default_convert(data) self.assertEqual(converted, data) def test_default_convert_sequence_keep_type(self): data = CustomList([1, 2, 3]) - converted = _utils.collate.default_convert(data) + converted = dataloader.default_convert(data) self.assertEqual(converted, data) def test_default_convert_sequence_dont_keep_type(self): data = range(2) - converted = _utils.collate.default_convert(data) + converted = dataloader.default_convert(data) self.assertEqual(converted, [0, 1]) def test_default_collate_dtype(self): arr = [1, 2, -1] - collated = _utils.collate.default_collate(arr) + collated = dataloader.default_collate(arr) self.assertEqual(collated, torch.tensor(arr)) self.assertEqual(collated.dtype, torch.int64) arr = [1.1, 2.3, -0.9] - collated = _utils.collate.default_collate(arr) + collated = dataloader.default_collate(arr) self.assertEqual(collated, torch.tensor(arr, dtype=torch.float64)) arr = [True, False] - collated = _utils.collate.default_collate(arr) + collated = dataloader.default_collate(arr) self.assertEqual(collated, torch.tensor(arr)) self.assertEqual(collated.dtype, torch.bool) # Should be a no-op arr = ["a", "b", "c"] - self.assertEqual(arr, _utils.collate.default_collate(arr)) + self.assertEqual(arr, dataloader.default_collate(arr)) def test_default_collate_mapping_keep_type(self): batch = [CustomDict({"a": 1, "b": 2}), CustomDict({"a": 3, "b": 4})] - collated = _utils.collate.default_collate(batch) + collated = dataloader.default_collate(batch) expected = CustomDict({"a": torch.tensor([1, 3]), "b": torch.tensor([2, 4])}) self.assertEqual(collated, expected) def test_default_collate_sequence_keep_type(self): batch = [CustomList([1, 2, 3]), CustomList([4, 5, 6])] - collated = _utils.collate.default_collate(batch) + collated = dataloader.default_collate(batch) expected = CustomList( [ @@ -2815,7 +2816,7 @@ except RuntimeError as e: def test_default_collate_sequence_dont_keep_type(self): batch = [range(2), range(2)] - collated = _utils.collate.default_collate(batch) + collated = dataloader.default_collate(batch) self.assertEqual(collated, [torch.tensor([0, 0]), torch.tensor([1, 1])]) @@ -2825,16 +2826,16 @@ except RuntimeError as e: # Should be a no-op arr = np.array(["a", "b", "c"]) - self.assertEqual(arr, _utils.collate.default_collate(arr)) + self.assertEqual(arr, dataloader.default_collate(arr)) arr = np.array([[["a", "b", "c"]]]) - self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) + self.assertRaises(TypeError, lambda: dataloader.default_collate(arr)) arr = np.array([object(), object(), object()]) - self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) + self.assertRaises(TypeError, lambda: dataloader.default_collate(arr)) arr = np.array([[[object(), object(), object()]]]) - self.assertRaises(TypeError, lambda: _utils.collate.default_collate(arr)) + self.assertRaises(TypeError, lambda: dataloader.default_collate(arr)) @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") def test_default_collate_numpy_memmap(self): @@ -2845,7 +2846,7 @@ except RuntimeError as e: arr_memmap = np.memmap(f, dtype=arr.dtype, mode="w+", shape=arr.shape) arr_memmap[:] = arr[:] arr_new = np.memmap(f, dtype=arr.dtype, mode="r", shape=arr.shape) - tensor = _utils.collate.default_collate(list(arr_new)) + tensor = dataloader.default_collate(list(arr_new)) self.assertTrue( (tensor == tensor.new_tensor([[0, 1], [2, 3], [4, 5], [6, 7]])).all().item() @@ -2853,10 +2854,8 @@ except RuntimeError as e: def test_default_collate_bad_sequence_type(self): batch = [["X"], ["X", "X"]] - self.assertRaises(RuntimeError, lambda: _utils.collate.default_collate(batch)) - self.assertRaises( - RuntimeError, lambda: _utils.collate.default_collate(batch[::-1]) - ) + self.assertRaises(RuntimeError, lambda: dataloader.default_collate(batch)) + self.assertRaises(RuntimeError, lambda: dataloader.default_collate(batch[::-1])) @unittest.skipIf(not TEST_NUMPY, "numpy unavailable") def test_default_collate_shared_tensor(self): @@ -2867,8 +2866,8 @@ except RuntimeError as e: self.assertEqual(t_in.is_shared(), False) - self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), False) - self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), False) + self.assertEqual(dataloader.default_collate([t_in]).is_shared(), False) + self.assertEqual(dataloader.default_collate([n_in]).is_shared(), False) # FIXME: fix the following hack that makes `default_collate` believe # that it is in a worker process (since it tests @@ -2876,8 +2875,8 @@ except RuntimeError as e: old = _utils.worker._worker_info try: _utils.worker._worker_info = "x" - self.assertEqual(_utils.collate.default_collate([t_in]).is_shared(), True) - self.assertEqual(_utils.collate.default_collate([n_in]).is_shared(), True) + self.assertEqual(dataloader.default_collate([t_in]).is_shared(), True) + self.assertEqual(dataloader.default_collate([n_in]).is_shared(), True) finally: _utils.worker._worker_info = old