Support torch.Tensor.to for CSR

Fixes #76379

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76400
Approved by: https://github.com/pearu, https://github.com/davidberard98
This commit is contained in:
Christian Puhrsch 2022-05-05 21:59:50 +00:00 committed by PyTorch MergeBot
parent 52af4fc5ba
commit ce9a477fdf
3 changed files with 86 additions and 5 deletions

View file

@ -69,6 +69,12 @@ SparseCsrTensorImpl::SparseCsrTensorImpl(
set_storage_access_should_throw();
is_non_overlapping_and_dense_ = false;
set_has_contiguity_policy(HasContiguityPolicy::ContiguityNotSupported);
// TODO: If this check ever shows up as a bottleneck, which is unlikely given that
// comparing devices only involves comparing the type and index (two integers), we
// can move this to a DEBUG only assert. Until then this confirms and maintains a
// crucial invariant.
TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and crow_indices need to be on the same device.");
TORCH_CHECK(values_.device() == col_indices_.device(), "Values and col_indices need to be on the same device.");
}
const char* SparseCsrTensorImpl::tensorimpl_type_name() const {
@ -134,6 +140,12 @@ void SparseCsrTensorImpl::set_member_tensors(
sizes_and_strides_.set_sizes(size);
refresh_numel();
// TODO: If this check ever shows up as a bottleneck, which is unlikely given that
// comparing devices only involves comparing the type and index (two integers), we
// can move this to a DEBUG only assert. Until then this confirms and maintains a
// crucial invariant.
TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and crow_indices need to be on the same device.");
TORCH_CHECK(values_.device() == col_indices_.device(), "Values and col_indices need to be on the same device.");
}
IntArrayRef SparseCsrTensorImpl::strides() const {

View file

@ -51,6 +51,54 @@ Tensor _to_copy(
// memory_format is handled separately due to MemoryFormat::Preserve logic
options = self.options().merge_in(options).memory_format(c10::nullopt);
auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
// TODO: Use the dispatcher for this.
// Currently there are unenumerated extensibility issues preventing this.
if (self.is_sparse_csr()) {
TORCH_CHECK(
memory_format == MemoryFormat::Preserve,
"sparse_csr only supports memory format Preserve, but got ",
memory_format,
" instead.");
auto new_values = at::native::to(
self.values(),
dtype,
c10::kStrided, // values are strided
device,
pin_memory,
non_blocking,
true, // force copy since we're in _to_copy
memory_format);
auto new_crow_indices = at::native::to(
self.crow_indices(),
self.crow_indices().scalar_type(), // indices are integral
c10::kStrided, // indices are strided
device,
pin_memory,
non_blocking,
true, // force copy since we're in _to_copy
memory_format);
auto new_col_indices = at::native::to(
self.col_indices(),
self.col_indices().scalar_type(), // indices are integral
c10::kStrided, // indices are strided
device,
pin_memory,
non_blocking,
true, // force copy since we're in _to_copy
memory_format);
return at::native::_sparse_csr_tensor_unsafe(
new_crow_indices,
new_col_indices,
new_values,
self.sizes(),
new_values.scalar_type(),
self.layout(),
new_values.device());
}
bool pin_out = (non_blocking && self.is_cuda() && options.device().is_cpu() &&
(options.layout() == c10::kStrided));

View file

@ -7549,7 +7549,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6)))
# FIXME: Port to a more appropriate test suite
def test_to(self):
def _test_to_with_layout(self, layout):
def test_copy_behavior(t, non_blocking=False):
self.assertIs(t, t.to(t, non_blocking=non_blocking))
self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
@ -7571,16 +7571,33 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True))
a = torch.tensor(5)
if layout == torch.sparse_csr:
a = torch.tensor([[0, 1, 2], [2, 0, 3]]).to_sparse_csr()
test_copy_behavior(a)
self.assertEqual(a.device, a.to('cpu').device)
self.assertEqual(a.device, a.to('cpu', dtype=torch.float32).device)
self.assertIs(torch.float32, a.to('cpu', dtype=torch.float32).dtype)
self.assertEqual(a.device, a.to(torch.float32).device)
self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype)
self.assertEqual(a.data_ptr(), a.to('cpu').data_ptr())
self.assertEqual(a.data_ptr(), a.to(dtype=a.dtype, device=a.device, copy=False).data_ptr())
self.assertEqual(a.data_ptr(), a.to('cpu', copy=False).data_ptr())
self.assertNotEqual(a.data_ptr(), a.to('cpu', copy=True).data_ptr())
def test_data_ptr(getter):
    # A no-op `to` (same dtype/device, or copy=False) must return a tensor
    # aliasing the original storage; copy=True must allocate fresh storage.
    # `getter` extracts a data pointer — either tensor.data_ptr() directly,
    # or a component's data_ptr() for layouts (e.g. CSR) without storage.
    self.assertEqual(getter(a), getter(a.to('cpu')))
    self.assertEqual(getter(a), getter(a.to(dtype=a.dtype, device=a.device, copy=False)))
    self.assertEqual(getter(a), getter(a.to('cpu', copy=False)))
    self.assertNotEqual(getter(a), getter(a.to('cpu', copy=True)))
if layout == torch.sparse_csr:
# TODO: compressed sparse tensors currently don't support data_ptr.
# Exercising failure will allow us to widen coverage of this test once it does.
with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer of Tensor that doesn't have storage"):
a.data_ptr()
# While compressed sparse tensors don't have a concept of data_ptr
# the underlying tensors do. The implementation of to appropriately forwards
# the call to the components, which is what we're testing here.
test_data_ptr(lambda a: a.values().data_ptr())
test_data_ptr(lambda a: a.crow_indices().data_ptr())
test_data_ptr(lambda a: a.col_indices().data_ptr())
else:
test_data_ptr(lambda a: a.data_ptr())
if torch.cuda.is_available():
for non_blocking in [True, False]:
@ -7595,6 +7612,10 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype)
self.assertEqual(b.device, b.to(dtype=torch.int32).device)
def test_to(self):
    # Run the shared `to` test body once per layout: dense (strided) and
    # sparse CSR, the layout added by this change.
    self._test_to_with_layout(torch.strided)
    self._test_to_with_layout(torch.sparse_csr)
# FIXME: describe this test
def test_as_subclass(self):
class SubTensor(torch.Tensor):