diff --git a/aten/src/ATen/SparseCsrTensorImpl.cpp b/aten/src/ATen/SparseCsrTensorImpl.cpp index 5be4a5e9293..ebba1f70666 100644 --- a/aten/src/ATen/SparseCsrTensorImpl.cpp +++ b/aten/src/ATen/SparseCsrTensorImpl.cpp @@ -69,6 +69,12 @@ SparseCsrTensorImpl::SparseCsrTensorImpl( set_storage_access_should_throw(); is_non_overlapping_and_dense_ = false; set_has_contiguity_policy(HasContiguityPolicy::ContiguityNotSupported); + // TODO: If this check ever shows up as a bottleneck, which is unlikely given that + // comparing devices only involves comparing the type and index (two integers), we + // can move this to a DEBUG only assert. Until then this confirms and maintains a + // crucial invariant. + TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and crow_indices need to be on the same device."); + TORCH_CHECK(values_.device() == col_indices_.device(), "Values and col_indices need to be on the same device."); } const char* SparseCsrTensorImpl::tensorimpl_type_name() const { @@ -134,6 +140,12 @@ void SparseCsrTensorImpl::set_member_tensors( sizes_and_strides_.set_sizes(size); refresh_numel(); + // TODO: If this check ever shows up as a bottleneck, which is unlikely given that + // comparing devices only involves comparing the type and index (two integers), we + // can move this to a DEBUG only assert. Until then this confirms and maintains a + // crucial invariant. 
+ TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and crow_indices need to be on the same device."); + TORCH_CHECK(values_.device() == col_indices_.device(), "Values and col_indices need to be on the same device."); } IntArrayRef SparseCsrTensorImpl::strides() const { diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index e63da81b4a5..2f1ee53c342 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -51,6 +51,54 @@ Tensor _to_copy( // memory_format is handled separately due to MemoryFormat::Preserve logic options = self.options().merge_in(options).memory_format(c10::nullopt); auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve); + // TODO: Use the dispatcher for this. + // Currently there are unenumerated extensibility issues preventing this. + if (self.is_sparse_csr()) { + TORCH_CHECK( + memory_format == MemoryFormat::Preserve, + "sparse_csr only supports memory format Preserve, but got ", + memory_format, + " instead."); + + auto new_values = at::native::to( + self.values(), + dtype, + c10::kStrided, // values are strided + device, + pin_memory, + non_blocking, + true, // force copy since we're in _to_copy + memory_format); + + auto new_crow_indices = at::native::to( + self.crow_indices(), + self.crow_indices().scalar_type(), // indices are integral + c10::kStrided, // indices are strided + device, + pin_memory, + non_blocking, + true, // force copy since we're in _to_copy + memory_format); + + auto new_col_indices = at::native::to( + self.col_indices(), + self.col_indices().scalar_type(), // indices are integral + c10::kStrided, // indices are strided + device, + pin_memory, + non_blocking, + true, // force copy since we're in _to_copy + memory_format); + + return at::native::_sparse_csr_tensor_unsafe( + new_crow_indices, + new_col_indices, + new_values, + self.sizes(), + new_values.scalar_type(), + self.layout(), + 
new_values.device()); + } bool pin_out = (non_blocking && self.is_cuda() && options.device().is_cpu() && (options.layout() == c10::kStrided)); diff --git a/test/test_torch.py b/test/test_torch.py index 5e9d21d3e38..dee4ed5e504 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -7549,7 +7549,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6))) # FIXME: Port to a more appropriate test suite - def test_to(self): + def _test_to_with_layout(self, layout): def test_copy_behavior(t, non_blocking=False): self.assertIs(t, t.to(t, non_blocking=non_blocking)) self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking)) @@ -7571,16 +7571,33 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)) a = torch.tensor(5) + if layout == torch.sparse_csr: + a = torch.tensor([[0, 1, 2], [2, 0, 3]]).to_sparse_csr() test_copy_behavior(a) self.assertEqual(a.device, a.to('cpu').device) self.assertEqual(a.device, a.to('cpu', dtype=torch.float32).device) self.assertIs(torch.float32, a.to('cpu', dtype=torch.float32).dtype) self.assertEqual(a.device, a.to(torch.float32).device) self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype) - self.assertEqual(a.data_ptr(), a.to('cpu').data_ptr()) - self.assertEqual(a.data_ptr(), a.to(dtype=a.dtype, device=a.device, copy=False).data_ptr()) - self.assertEqual(a.data_ptr(), a.to('cpu', copy=False).data_ptr()) - self.assertNotEqual(a.data_ptr(), a.to('cpu', copy=True).data_ptr()) + + def test_data_ptr(getter): + self.assertEqual(getter(a), getter(a.to('cpu'))) + self.assertEqual(getter(a), getter(a.to(dtype=a.dtype, device=a.device, copy=False))) + self.assertEqual(getter(a), getter(a.to('cpu', copy=False))) + self.assertNotEqual(getter(a), getter(a.to('cpu', copy=True))) + if layout == torch.sparse_csr: + # TODO: compressed sparse tensors 
currently don't support data_ptr. + # Exercising failure will allow us to widen coverage of this test once it does. + with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer of Tensor that doesn't have storage"): + a.data_ptr() + # While compressed sparse tensors don't have a concept of data_ptr + # the underlying tensors do. The implementation of to appropriately forwards + # the call to the components, which is what we're testing here. + test_data_ptr(lambda a: a.values().data_ptr()) + test_data_ptr(lambda a: a.crow_indices().data_ptr()) + test_data_ptr(lambda a: a.col_indices().data_ptr()) + else: + test_data_ptr(lambda a: a.data_ptr()) if torch.cuda.is_available(): for non_blocking in [True, False]: @@ -7595,6 +7612,10 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j], self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype) self.assertEqual(b.device, b.to(dtype=torch.int32).device) + def test_to(self): + self._test_to_with_layout(torch.strided) + self._test_to_with_layout(torch.sparse_csr) + # FIXME: describe this test def test_as_subclass(self): class SubTensor(torch.Tensor):