mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Support torch.Tensor.to for CSR
Fixes #76379 Pull Request resolved: https://github.com/pytorch/pytorch/pull/76400 Approved by: https://github.com/pearu, https://github.com/davidberard98
This commit is contained in:
parent
52af4fc5ba
commit
ce9a477fdf
3 changed files with 86 additions and 5 deletions
|
|
@ -69,6 +69,12 @@ SparseCsrTensorImpl::SparseCsrTensorImpl(
|
|||
set_storage_access_should_throw();
|
||||
is_non_overlapping_and_dense_ = false;
|
||||
set_has_contiguity_policy(HasContiguityPolicy::ContiguityNotSupported);
|
||||
// TODO: If this check ever shows up as a bottleneck, which is unlikely given that
|
||||
// comparing devices only involves comparing the type and index (two integers), we
|
||||
// can move this to a DEBUG only assert. Until then this confirms and maintains a
|
||||
// crucial invariant.
|
||||
TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and crow_indices need to be on the same device.");
|
||||
TORCH_CHECK(values_.device() == col_indices_.device(), "Values and col_indices need to be on the same device.");
|
||||
}
|
||||
|
||||
const char* SparseCsrTensorImpl::tensorimpl_type_name() const {
|
||||
|
|
@ -134,6 +140,12 @@ void SparseCsrTensorImpl::set_member_tensors(
|
|||
|
||||
sizes_and_strides_.set_sizes(size);
|
||||
refresh_numel();
|
||||
// TODO: If this check ever shows up as a bottleneck, which is unlikely given that
|
||||
// comparing devices only involves comparing the type and index (two integers), we
|
||||
// can move this to a DEBUG only assert. Until then this confirms and maintains a
|
||||
// crucial invariant.
|
||||
TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and crow_indices need to be on the same device.");
|
||||
TORCH_CHECK(values_.device() == col_indices_.device(), "Values and col_indices need to be on the same device.");
|
||||
}
|
||||
|
||||
IntArrayRef SparseCsrTensorImpl::strides() const {
|
||||
|
|
|
|||
|
|
@ -51,6 +51,54 @@ Tensor _to_copy(
|
|||
// memory_format is handled separately due to MemoryFormat::Preserve logic
|
||||
options = self.options().merge_in(options).memory_format(c10::nullopt);
|
||||
auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve);
|
||||
// TODO: Use the dispatcher for this.
|
||||
// Currently there are unenumerated extensibility issues preventing this.
|
||||
if (self.is_sparse_csr()) {
|
||||
TORCH_CHECK(
|
||||
memory_format == MemoryFormat::Preserve,
|
||||
"sparse_csr only supports memory format Preserve, but got ",
|
||||
memory_format,
|
||||
" instead.");
|
||||
|
||||
auto new_values = at::native::to(
|
||||
self.values(),
|
||||
dtype,
|
||||
c10::kStrided, // values are strided
|
||||
device,
|
||||
pin_memory,
|
||||
non_blocking,
|
||||
true, // force copy since we're in _to_copy
|
||||
memory_format);
|
||||
|
||||
auto new_crow_indices = at::native::to(
|
||||
self.crow_indices(),
|
||||
self.crow_indices().scalar_type(), // indices are integral
|
||||
c10::kStrided, // indices are strided
|
||||
device,
|
||||
pin_memory,
|
||||
non_blocking,
|
||||
true, // force copy since we're in _to_copy
|
||||
memory_format);
|
||||
|
||||
auto new_col_indices = at::native::to(
|
||||
self.col_indices(),
|
||||
self.col_indices().scalar_type(), // indices are integral
|
||||
c10::kStrided, // indices are strided
|
||||
device,
|
||||
pin_memory,
|
||||
non_blocking,
|
||||
true, // force copy since we're in _to_copy
|
||||
memory_format);
|
||||
|
||||
return at::native::_sparse_csr_tensor_unsafe(
|
||||
new_crow_indices,
|
||||
new_col_indices,
|
||||
new_values,
|
||||
self.sizes(),
|
||||
new_values.scalar_type(),
|
||||
self.layout(),
|
||||
new_values.device());
|
||||
}
|
||||
|
||||
bool pin_out = (non_blocking && self.is_cuda() && options.device().is_cpu() &&
|
||||
(options.layout() == c10::kStrided));
|
||||
|
|
|
|||
|
|
@ -7549,7 +7549,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||
self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6)))
|
||||
|
||||
# FIXME: Port to a more appropriate test suite
|
||||
def test_to(self):
|
||||
def _test_to_with_layout(self, layout):
|
||||
def test_copy_behavior(t, non_blocking=False):
|
||||
self.assertIs(t, t.to(t, non_blocking=non_blocking))
|
||||
self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
|
||||
|
|
@ -7571,16 +7571,33 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||
self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True))
|
||||
|
||||
a = torch.tensor(5)
|
||||
if layout == torch.sparse_csr:
|
||||
a = torch.tensor([[0, 1, 2], [2, 0, 3]]).to_sparse_csr()
|
||||
test_copy_behavior(a)
|
||||
self.assertEqual(a.device, a.to('cpu').device)
|
||||
self.assertEqual(a.device, a.to('cpu', dtype=torch.float32).device)
|
||||
self.assertIs(torch.float32, a.to('cpu', dtype=torch.float32).dtype)
|
||||
self.assertEqual(a.device, a.to(torch.float32).device)
|
||||
self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype)
|
||||
self.assertEqual(a.data_ptr(), a.to('cpu').data_ptr())
|
||||
self.assertEqual(a.data_ptr(), a.to(dtype=a.dtype, device=a.device, copy=False).data_ptr())
|
||||
self.assertEqual(a.data_ptr(), a.to('cpu', copy=False).data_ptr())
|
||||
self.assertNotEqual(a.data_ptr(), a.to('cpu', copy=True).data_ptr())
|
||||
|
||||
def test_data_ptr(getter):
|
||||
self.assertEqual(getter(a), getter(a.to('cpu')))
|
||||
self.assertEqual(getter(a), getter(a.to(dtype=a.dtype, device=a.device, copy=False)))
|
||||
self.assertEqual(getter(a), getter(a.to('cpu', copy=False)))
|
||||
self.assertNotEqual(getter(a), getter(a.to('cpu', copy=True)))
|
||||
if layout == torch.sparse_csr:
|
||||
# TODO: compressed sparse tensors currently don't support data_ptr.
|
||||
# Exercising failure will allow us to widen coverage of this test once it does.
|
||||
with self.assertRaisesRegex(RuntimeError, "Cannot access data pointer of Tensor that doesn't have storage"):
|
||||
a.data_ptr()
|
||||
# While compressed sparse tensors don't have a concept of data_ptr
|
||||
# the underlying tensors do. The implementation of `to` appropriately forwards
|
||||
# the call to the components, which is what we're testing here.
|
||||
test_data_ptr(lambda a: a.values().data_ptr())
|
||||
test_data_ptr(lambda a: a.crow_indices().data_ptr())
|
||||
test_data_ptr(lambda a: a.col_indices().data_ptr())
|
||||
else:
|
||||
test_data_ptr(lambda a: a.data_ptr())
|
||||
|
||||
if torch.cuda.is_available():
|
||||
for non_blocking in [True, False]:
|
||||
|
|
@ -7595,6 +7612,10 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
|||
self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype)
|
||||
self.assertEqual(b.device, b.to(dtype=torch.int32).device)
|
||||
|
||||
def test_to(self):
|
||||
self._test_to_with_layout(torch.strided)
|
||||
self._test_to_with_layout(torch.sparse_csr)
|
||||
|
||||
# FIXME: describe this test
|
||||
def test_as_subclass(self):
|
||||
class SubTensor(torch.Tensor):
|
||||
|
|
|
|||
Loading…
Reference in a new issue