add device checks for sparse csr (#97520)

Fixes #95373

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97520
Approved by: https://github.com/cpuhrsch
This commit is contained in:
cybershiptrooper 2023-03-27 18:57:23 +00:00 committed by PyTorch MergeBot
parent 96e3b3ac72
commit bbc7c79b20

View file

@ -263,6 +263,18 @@ Tensor& add_out_sparse_csr_cuda(
self.sizes(),
" and tensor `other` with shape ",
other.sizes());
TORCH_CHECK(
self.is_cuda(),
"add: expected 'self' to be CUDA tensor, but got tensor on device: ",
self.device());
TORCH_CHECK(
other.is_cuda(),
"add: expected 'other' to be CUDA tensor, but got tensor on device: ",
other.device());
TORCH_CHECK(
out.is_cuda(),
"add: expected 'out' to be CUDA tensor, but got tensor on device: ",
out.device());
if (only_sparse_compressed_add_trivial_cases(self, other, alpha, out)) {
return out;