mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Underscore prefix sparse_csr_tensor and to_sparse_csr (#59001)
* Underscore prefix sparse_csr_tensor and to_sparse_csr Signed-off-by: Edward Z. Yang <ezyang@fb.com> * fix lint Signed-off-by: Edward Z. Yang <ezyang@fb.com>
This commit is contained in:
parent
b5e2635281
commit
dfc58f4faa
16 changed files with 80 additions and 80 deletions
|
|
@ -4805,9 +4805,9 @@
|
|||
# FIXME: would be nicer if TensorOptions was optional based; not adding default arguments for options given
|
||||
# the default would never make sense.
|
||||
|
||||
- func: sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
- func: _sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
|
||||
- func: sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
- func: _sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
|
||||
- func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ SparseCsrTensor new_csr_tensor(const TensorOptions& options) {
|
|||
// TODO: This constructor should probably use an ATen abstract method in order
|
||||
// to make autograd dispatch available for the CSR constructor. See the relevant
|
||||
// note in native_functions.yaml.
|
||||
Tensor sparse_csr_tensor(
|
||||
Tensor _sparse_csr_tensor(
|
||||
const Tensor& crow_indices,
|
||||
const Tensor& col_indices,
|
||||
const Tensor& values,
|
||||
|
|
@ -86,7 +86,7 @@ Tensor sparse_csr_tensor(
|
|||
return self;
|
||||
}
|
||||
|
||||
Tensor sparse_csr_tensor(
|
||||
Tensor _sparse_csr_tensor(
|
||||
const Tensor& crow_indices,
|
||||
const Tensor& col_indices,
|
||||
const Tensor& values,
|
||||
|
|
@ -125,7 +125,7 @@ Tensor sparse_csr_tensor(
|
|||
size[1] = 0;
|
||||
}
|
||||
|
||||
return at::sparse_csr_tensor(
|
||||
return at::_sparse_csr_tensor(
|
||||
crow_indices, col_indices, values, size, options);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def gen_sparse_csr(shape, nnz):
|
|||
dense[f] = fill_value
|
||||
dense = torch.from_numpy(dense.reshape(shape))
|
||||
|
||||
return dense.to_sparse_csr()
|
||||
return dense._to_sparse_csr()
|
||||
|
||||
def gen_sparse_coo(shape, nnz):
|
||||
dense = np.random.randn(*shape)
|
||||
|
|
@ -51,4 +51,4 @@ def gen_sparse_coo_and_csr(shape, nnz):
|
|||
dense[f] = 0
|
||||
|
||||
dense = torch.from_numpy(dense.reshape(shape))
|
||||
return dense.to_sparse(), dense.to_sparse_csr()
|
||||
return dense.to_sparse(), dense._to_sparse_csr()
|
||||
|
|
|
|||
|
|
@ -397,7 +397,7 @@ and ``values``:
|
|||
Construction of CSR tensors
|
||||
---------------------------
|
||||
|
||||
Sparse CSR matrices can be directly constructed by using the :func:`torch.sparse_csr_tensor`
|
||||
Sparse CSR matrices can be directly constructed by using the :func:`torch._sparse_csr_tensor`
|
||||
method. The user must supply the row and column indices and values tensors separately.
|
||||
The ``size`` argument is optional and will be deduced from the the ``crow_indices``
|
||||
and ``col_indices`` if it is not present.
|
||||
|
|
@ -405,7 +405,7 @@ and ``col_indices`` if it is not present.
|
|||
>>> crow_indices = torch.tensor([0, 2, 4])
|
||||
>>> col_indices = torch.tensor([0, 1, 0, 1])
|
||||
>>> values = torch.tensor([1, 2, 3, 4])
|
||||
>>> csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.double)
|
||||
>>> csr = torch._sparse_csr_tensor(crow_indices, col_indices, values, dtype=torch.double)
|
||||
>>> csr
|
||||
tensor(crow_indices=tensor([0, 2, 4]),
|
||||
col_indices=tensor([0, 1, 0, 1]),
|
||||
|
|
@ -419,11 +419,11 @@ CSR Tensor Operations
|
|||
---------------------
|
||||
|
||||
The simplest way of constructing a sparse CSR tensor from a strided or sparse COO
|
||||
tensor is to use :meth:`tensor.to_sparse_csr`. Any zeros in the (strided) tensor will
|
||||
tensor is to use :meth:`tensor._to_sparse_csr`. Any zeros in the (strided) tensor will
|
||||
be interpreted as missing values in the sparse tensor:
|
||||
|
||||
>>> a = torch.tensor([[0, 0, 1, 0], [1, 2, 0, 0], [0, 0, 0, 0]], dtype = torch.float64)
|
||||
>>> sp = a.to_sparse_csr()
|
||||
>>> sp = a._to_sparse_csr()
|
||||
>>> sp
|
||||
tensor(crow_indices=tensor([0, 1, 3, 3]),
|
||||
col_indices=tensor([2, 0, 1]),
|
||||
|
|
@ -496,7 +496,7 @@ The following Tensor methods are related to sparse tensors:
|
|||
Tensor.sparse_dim
|
||||
Tensor.sparse_mask
|
||||
Tensor.to_sparse
|
||||
Tensor.to_sparse_csr
|
||||
Tensor._to_sparse_csr
|
||||
Tensor.indices
|
||||
Tensor.values
|
||||
|
||||
|
|
@ -581,7 +581,7 @@ Torch functions specific to sparse Tensors
|
|||
:nosignatures:
|
||||
|
||||
sparse_coo_tensor
|
||||
sparse_csr_tensor
|
||||
_sparse_csr_tensor
|
||||
sparse.sum
|
||||
sparse.addmm
|
||||
sparse.mm
|
||||
|
|
|
|||
|
|
@ -1699,10 +1699,10 @@ graph(%Ra, %Rb):
|
|||
self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
|
||||
|
||||
@suppress_warnings
|
||||
def test_sparse_csr_tensors(self):
|
||||
def test__sparse_csr_tensors(self):
|
||||
@torch.jit.ignore
|
||||
def get_sparse_csr():
|
||||
return torch.randn(3, 3).to_sparse_csr()
|
||||
return torch.randn(3, 3)._to_sparse_csr()
|
||||
|
||||
@torch.jit.script
|
||||
def test_is_sparse_csr(input):
|
||||
|
|
|
|||
|
|
@ -21,9 +21,9 @@ class TestSparseCSR(TestCase):
|
|||
crow_indices = [0, 2, 4]
|
||||
col_indices = [0, 1, 0, 1]
|
||||
values = [1, 2, 3, 4]
|
||||
sparse = torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64),
|
||||
torch.tensor(col_indices, dtype=torch.int64),
|
||||
torch.tensor(values), dtype=dtype, device=device)
|
||||
sparse = torch._sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64),
|
||||
torch.tensor(col_indices, dtype=torch.int64),
|
||||
torch.tensor(values), dtype=dtype, device=device)
|
||||
self.assertEqual(torch.tensor(crow_indices, dtype=torch.int64), sparse.crow_indices())
|
||||
self.assertEqual((len(crow_indices) - 1, max(col_indices) + 1), sparse.shape)
|
||||
self.assertEqual(dtype, sparse.dtype)
|
||||
|
|
@ -37,12 +37,12 @@ class TestSparseCSR(TestCase):
|
|||
col_indices = [0, 1, 0, 1]
|
||||
values = [1, 2, 3, 4]
|
||||
for index_dtype in [torch.int32, torch.int64]:
|
||||
sparse = torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=index_dtype),
|
||||
torch.tensor(col_indices, dtype=index_dtype),
|
||||
torch.tensor(values),
|
||||
size=(2, 10),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
sparse = torch._sparse_csr_tensor(torch.tensor(crow_indices, dtype=index_dtype),
|
||||
torch.tensor(col_indices, dtype=index_dtype),
|
||||
torch.tensor(values),
|
||||
size=(2, 10),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
self.assertEqual((2, 10), sparse.shape)
|
||||
self.assertEqual(torch.tensor(crow_indices, dtype=index_dtype), sparse.crow_indices())
|
||||
self.assertEqual(torch.tensor(col_indices, dtype=index_dtype), sparse.col_indices())
|
||||
|
|
@ -55,27 +55,27 @@ class TestSparseCSR(TestCase):
|
|||
col_indices = [0, 1, 0, 1]
|
||||
values = [1, 2, 3, 4]
|
||||
size = (2, 10)
|
||||
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), size,
|
||||
dtype=dtype, device=device)
|
||||
torch._sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), size,
|
||||
dtype=dtype, device=device)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
r"crow_indices\.numel\(\) must be size\(0\) \+ 1, but got: 3"):
|
||||
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), (1, 1),
|
||||
dtype=dtype, device=device)
|
||||
torch._sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), (1, 1),
|
||||
dtype=dtype, device=device)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "0th value of crow_indices must be 0"):
|
||||
torch.sparse_csr_tensor(torch.tensor([-1, -1, -1]), torch.tensor(col_indices), torch.tensor(values), size,
|
||||
dtype=dtype, device=device)
|
||||
torch._sparse_csr_tensor(torch.tensor([-1, -1, -1]), torch.tensor(col_indices), torch.tensor(values), size,
|
||||
dtype=dtype, device=device)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "last value of crow_indices should be less than length of col_indices."):
|
||||
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 0, 0]), torch.tensor(values), size,
|
||||
dtype=dtype, device=device)
|
||||
torch._sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 0, 0]), torch.tensor(values), size,
|
||||
dtype=dtype, device=device)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
r"col_indices and values must have equal sizes, " +
|
||||
r"but got col_indices\.size\(0\): 4, values\.size\(0\): 5"):
|
||||
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor([0, 0, 0, 0, 0]),
|
||||
size, dtype=dtype, device=device)
|
||||
torch._sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor([0, 0, 0, 0, 0]),
|
||||
size, dtype=dtype, device=device)
|
||||
|
||||
@onlyCPU
|
||||
def test_sparse_csr_print(self, device):
|
||||
|
|
@ -113,19 +113,19 @@ class TestSparseCSR(TestCase):
|
|||
@onlyCPU
|
||||
def test_sparse_csr_from_dense(self, device):
|
||||
dense = torch.tensor([[4, 5, 0], [0, 0, 0], [1, 0, 0]], device=device)
|
||||
sparse = dense.to_sparse_csr()
|
||||
sparse = dense._to_sparse_csr()
|
||||
self.assertEqual(torch.tensor([0, 2, 2, 3], dtype=torch.int64), sparse.crow_indices())
|
||||
self.assertEqual(torch.tensor([0, 1, 0], dtype=torch.int64), sparse.col_indices())
|
||||
self.assertEqual(torch.tensor([4, 5, 1]), sparse.values())
|
||||
|
||||
dense = torch.tensor([[0, 0, 0], [0, 0, 1], [1, 0, 0]], device=device)
|
||||
sparse = dense.to_sparse_csr()
|
||||
sparse = dense._to_sparse_csr()
|
||||
self.assertEqual(torch.tensor([0, 0, 1, 2], dtype=torch.int64), sparse.crow_indices())
|
||||
self.assertEqual(torch.tensor([2, 0], dtype=torch.int64), sparse.col_indices())
|
||||
self.assertEqual(torch.tensor([1, 1]), sparse.values())
|
||||
|
||||
dense = torch.tensor([[2, 2, 2], [2, 2, 2], [2, 2, 2]], device=device)
|
||||
sparse = dense.to_sparse_csr()
|
||||
sparse = dense._to_sparse_csr()
|
||||
self.assertEqual(torch.tensor([0, 3, 6, 9], dtype=torch.int64), sparse.crow_indices())
|
||||
self.assertEqual(torch.tensor([0, 1, 2] * 3, dtype=torch.int64), sparse.col_indices())
|
||||
self.assertEqual(torch.tensor([2] * 9), sparse.values())
|
||||
|
|
@ -135,19 +135,19 @@ class TestSparseCSR(TestCase):
|
|||
def test_dense_convert(self, device, dtype):
|
||||
size = (5, 5)
|
||||
dense = torch.randn(size, dtype=dtype, device=device)
|
||||
sparse = dense.to_sparse_csr()
|
||||
sparse = dense._to_sparse_csr()
|
||||
self.assertEqual(sparse.to_dense(), dense)
|
||||
|
||||
size = (4, 6)
|
||||
dense = torch.randn(size, dtype=dtype, device=device)
|
||||
sparse = dense.to_sparse_csr()
|
||||
sparse = dense._to_sparse_csr()
|
||||
self.assertEqual(sparse.to_dense(), dense)
|
||||
|
||||
crow_indices = torch.tensor([0, 3, 5])
|
||||
col_indices = torch.tensor([0, 1, 2, 0, 1])
|
||||
values = torch.tensor([1, 2, 1, 3, 4], dtype=dtype)
|
||||
csr = torch.sparse_csr_tensor(crow_indices, col_indices,
|
||||
values, dtype=dtype, device=device)
|
||||
csr = torch._sparse_csr_tensor(crow_indices, col_indices,
|
||||
values, dtype=dtype, device=device)
|
||||
dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device)
|
||||
self.assertEqual(csr.to_dense(), dense)
|
||||
|
||||
|
|
@ -159,7 +159,7 @@ class TestSparseCSR(TestCase):
|
|||
sparse_dim = 2
|
||||
nnz = 10
|
||||
sparse_coo, _, _ = self.genSparseTensor(size, sparse_dim, nnz, coalesced, device, dtype)
|
||||
sparse_csr = sparse_coo.to_sparse_csr()
|
||||
sparse_csr = sparse_coo._to_sparse_csr()
|
||||
|
||||
self.assertTrue(sparse_csr.is_sparse_csr)
|
||||
self.assertEqual(sparse_csr.to_dense(), sparse_coo.to_dense())
|
||||
|
|
@ -177,7 +177,7 @@ class TestSparseCSR(TestCase):
|
|||
], dtype=torch.int32)
|
||||
values = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype, device=device)
|
||||
coo = torch.sparse_coo_tensor(index, values, torch.Size([100, 100]), dtype=dtype, device=device)
|
||||
csr = coo.to_sparse_csr()
|
||||
csr = coo._to_sparse_csr()
|
||||
|
||||
self.assertEqual(coo.matmul(vec), csr.matmul(vec))
|
||||
|
||||
|
|
@ -186,9 +186,9 @@ class TestSparseCSR(TestCase):
|
|||
def test_mkl_matvec_warnings(self, device, dtype):
|
||||
if torch.has_mkl:
|
||||
for index_dtype in [torch.int32, torch.int64]:
|
||||
sp = torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
|
||||
torch.tensor([0, 1, 0, 1]),
|
||||
torch.tensor([1, 2, 3, 4], dtype=dtype, device=device))
|
||||
sp = torch._sparse_csr_tensor(torch.tensor([0, 2, 4]),
|
||||
torch.tensor([0, 1, 0, 1]),
|
||||
torch.tensor([1, 2, 3, 4], dtype=dtype, device=device))
|
||||
vec = torch.randn((2, 1), dtype=dtype, device=device)
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
sp.matmul(vec)
|
||||
|
|
@ -204,7 +204,7 @@ class TestSparseCSR(TestCase):
|
|||
dense = torch.randn(size, dtype=dtype, device=device)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Only 2D"):
|
||||
sparse = dense.to_sparse_csr()
|
||||
sparse = dense._to_sparse_csr()
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float, torch.double)
|
||||
|
|
@ -229,7 +229,7 @@ class TestSparseCSR(TestCase):
|
|||
size = (5, 5)
|
||||
dense = torch.randn(size, dtype=dtype, device=device)
|
||||
coo_sparse = dense.to_sparse()
|
||||
csr_sparse = coo_sparse.to_sparse_csr()
|
||||
csr_sparse = coo_sparse._to_sparse_csr()
|
||||
|
||||
self.assertEqual(csr_sparse.to_dense(), dense)
|
||||
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ SKIP_PYTHON_BINDINGS = [
|
|||
'alias', 'contiguous', 'is_cuda', 'is_sparse', 'is_sparse_csr', 'size', 'stride',
|
||||
'.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
|
||||
'.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
|
||||
'_?sparse_csr_tensor.*',
|
||||
'_?_sparse_csr_tensor.*',
|
||||
'_arange.*', '_range.*', '_linspace.*', '_logspace.*',
|
||||
'_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*', '_sparse_dense_add_out',
|
||||
'index', 'unique_dim_consecutive',
|
||||
|
|
|
|||
|
|
@ -406,11 +406,11 @@ static std::vector<Tensor> dispatch_nonzero_numpy(const Tensor & self) {
|
|||
|
||||
static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs);
|
||||
|
||||
static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
|
||||
static PyObject * THPVariable__sparse_csr_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
jit::tracer::warn("torch.sparse_csr_tensor", jit::tracer::WARN_CONSTRUCTOR);
|
||||
return THPVariable_Wrap(torch::utils::sparse_csr_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs));
|
||||
jit::tracer::warn("torch._sparse_csr_tensor", jit::tracer::WARN_CONSTRUCTOR);
|
||||
return THPVariable_Wrap(torch::utils::_sparse_csr_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
|
@ -493,7 +493,7 @@ static PyMethodDef torch_functions[] = {
|
|||
{"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"saddmm", castPyCFunctionWithKeywords(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"_sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable__sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"_validate_sparse_coo_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_coo_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
{"spmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
|
||||
|
|
|
|||
|
|
@ -291,10 +291,10 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
|
|||
'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
|
||||
' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
|
||||
' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
|
||||
'sparse_csr_tensor' : ['def sparse_csr_tensor(crow_indices: Tensor, col_indices: Tensor,'
|
||||
' values: Tensor, size: Optional[_size]=None,'
|
||||
' *, dtype: Optional[_dtype]=None,'
|
||||
' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
|
||||
'_sparse_csr_tensor' : ['def _sparse_csr_tensor(crow_indices: Tensor, col_indices: Tensor,'
|
||||
' values: Tensor, size: Optional[_size]=None,'
|
||||
' *, dtype: Optional[_dtype]=None,'
|
||||
' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
|
||||
'_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
|
||||
' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
|
||||
' requires_grad: bool = False) -> Tensor: ...'],
|
||||
|
|
|
|||
|
|
@ -920,13 +920,13 @@ class Tensor(torch._C._TensorBase):
|
|||
# See Note [rename_ / rename API]
|
||||
return update_names(self, names, rename_map, inplace=False)
|
||||
|
||||
def to_sparse_csr(self):
|
||||
def _to_sparse_csr(self):
|
||||
""" Convert a tensor to compressed row storage format. Only works with 2D tensors.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> dense = torch.randn(5, 5)
|
||||
>>> sparse = dense.to_sparse_csr()
|
||||
>>> sparse = dense._to_sparse_csr()
|
||||
>>> sparse._nnz()
|
||||
25
|
||||
|
||||
|
|
@ -946,13 +946,13 @@ class Tensor(torch._C._TensorBase):
|
|||
i += 1
|
||||
ro.append(i)
|
||||
|
||||
return torch.sparse_csr_tensor(torch.tensor(ro, dtype=row_indices.dtype),
|
||||
coalesced_self.indices()[1], coalesced_self.values(),
|
||||
size=coalesced_self.shape, dtype=coalesced_self.dtype)
|
||||
return torch._sparse_csr_tensor(torch.tensor(ro, dtype=row_indices.dtype),
|
||||
coalesced_self.indices()[1], coalesced_self.values(),
|
||||
size=coalesced_self.shape, dtype=coalesced_self.dtype)
|
||||
elif self.is_sparse_csr:
|
||||
return self
|
||||
else:
|
||||
return self.to_sparse().to_sparse_csr()
|
||||
return self.to_sparse()._to_sparse_csr()
|
||||
|
||||
def _update_names(self, names, inplace):
|
||||
if has_torch_function_unary(self):
|
||||
|
|
|
|||
|
|
@ -4777,7 +4777,7 @@ matrix multiplication, it is necessary to use ``int32`` indexing in order
|
|||
to avoid downcasting and potentially losing information.
|
||||
|
||||
Example::
|
||||
>>> csr = torch.eye(5,5).to_sparse_csr()
|
||||
>>> csr = torch.eye(5,5)._to_sparse_csr()
|
||||
>>> csr.crow_indices()
|
||||
tensor([0, 1, 2, 3, 4, 5], dtype=torch.int32)
|
||||
|
||||
|
|
@ -4795,7 +4795,7 @@ matrix multiplication, it is necessary to use ``int32`` indexing in order
|
|||
to avoid downcasting and potentially losing information.
|
||||
|
||||
Example::
|
||||
>>> csr = torch.eye(5,5).to_sparse_csr()
|
||||
>>> csr = torch.eye(5,5)._to_sparse_csr()
|
||||
>>> csr.col_indices()
|
||||
tensor([0, 1, 2, 3, 4], dtype=torch.int32)
|
||||
|
||||
|
|
|
|||
|
|
@ -8156,9 +8156,9 @@ Example::
|
|||
[-0.0881, 0.4370, 0.2275, 1.0284]])
|
||||
""".format(**common_args))
|
||||
|
||||
add_docstr(torch.sparse_csr_tensor,
|
||||
add_docstr(torch._sparse_csr_tensor,
|
||||
r"""
|
||||
sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
|
||||
_sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
|
||||
|
||||
Constructs a :ref:`sparse tensor in CSR (Compressed Sparse Row) <sparse-csr-docs>` with specified
|
||||
values at the given :attr:`crow_indices` and :attr:`col_indices`. Sparse matrix multiplication operations
|
||||
|
|
@ -8190,7 +8190,7 @@ Example ::
|
|||
>>> crow_indices = [0, 2, 4]
|
||||
>>> col_indices = [0, 1, 0, 1]
|
||||
>>> values = [1, 2, 3, 4]
|
||||
>>> torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64),
|
||||
>>> torch._sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64),
|
||||
... torch.tensor(col_indices, dtype=torch.int64),
|
||||
... torch.tensor(values), dtype=torch.double)
|
||||
tensor(crow_indices=tensor([0, 2, 4]),
|
||||
|
|
|
|||
|
|
@ -603,11 +603,11 @@ Tensor indexing_tensor_from_data(
|
|||
}
|
||||
}
|
||||
|
||||
Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
|
||||
Tensor _sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
|
||||
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
|
||||
static PythonArgParser parser({
|
||||
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
|
||||
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
|
||||
"_sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
|
||||
"_sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
|
||||
});
|
||||
const int NUM_ARGS = 9, CROW_INDICES_ARG = 0, COL_INDICES_ARG = 1, VALUES_ARG = 2;
|
||||
ParsedArgs<NUM_ARGS> parsed_args;
|
||||
|
|
@ -638,7 +638,7 @@ Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scal
|
|||
/*copy_variables=*/false, /*copy_numpy=*/true,
|
||||
/*type_inference=*/false);
|
||||
|
||||
return at::sparse_csr_tensor(crow_indices, col_indices, values, r.intlist(SIZE_ARRAY_ARG),
|
||||
return at::_sparse_csr_tensor(crow_indices, col_indices, values, r.intlist(SIZE_ARRAY_ARG),
|
||||
values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(REQ_GRAD_ARG));
|
||||
} else if (r.idx == 1) {
|
||||
const int TYPE_INFERENCE_ARG = 3, DEVICE_TYPE_ARG = 5, REQ_GRAD_ARG = 7;
|
||||
|
|
@ -657,10 +657,10 @@ Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scal
|
|||
Tensor col_indices = internal_new_from_data(values.options(), col_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG),
|
||||
r.pyobject(COL_INDICES_ARG), /*copy_variables=*/false, /*copy_numpy=*/true,
|
||||
/*type_inference=*/false);
|
||||
return at::sparse_csr_tensor(crow_indices, col_indices, values,
|
||||
return at::_sparse_csr_tensor(crow_indices, col_indices, values,
|
||||
values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(REQ_GRAD_ARG));
|
||||
}
|
||||
throw std::runtime_error("sparse_csr_tensor(): invalid arguments");
|
||||
throw std::runtime_error("_sparse_csr_tensor(): invalid arguments");
|
||||
}
|
||||
|
||||
// Note [Ensuring sparse values and indices match devices]
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ at::Tensor indexing_tensor_from_data(
|
|||
at::ScalarType scalar_type,
|
||||
c10::optional<at::Device> device,
|
||||
PyObject* data);
|
||||
at::Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
at::Tensor _sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
at::Tensor sparse_coo_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
at::Tensor _sparse_coo_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
void _validate_sparse_coo_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
|
|
|
|||
|
|
@ -162,7 +162,7 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
torch.result_type,
|
||||
torch.scalar_tensor,
|
||||
torch.sparse_coo_tensor,
|
||||
torch.sparse_csr_tensor,
|
||||
torch._sparse_csr_tensor,
|
||||
torch.tril_indices,
|
||||
torch.triu_indices,
|
||||
torch.vander,
|
||||
|
|
@ -221,7 +221,7 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
Tensor._make_subclass,
|
||||
Tensor.stride,
|
||||
Tensor.unflatten,
|
||||
Tensor.to_sparse_csr,
|
||||
Tensor._to_sparse_csr,
|
||||
Tensor._reduce_ex_internal,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1074,9 +1074,9 @@ class TestCase(expecttest.TestCase):
|
|||
return values, crow_indices, col_indices
|
||||
|
||||
values, crow_indices, col_indices = random_sparse_csr(size[0], size[1], nnz)
|
||||
return torch.sparse_csr_tensor(crow_indices,
|
||||
col_indices,
|
||||
values, size=size, dtype=dtype, device=device)
|
||||
return torch._sparse_csr_tensor(crow_indices,
|
||||
col_indices,
|
||||
values, size=size, dtype=dtype, device=device)
|
||||
|
||||
def genSparseTensor(self, size, sparse_dim, nnz, is_uncoalesced, device, dtype):
|
||||
# Assert not given impossible combination, where the sparse dims have
|
||||
|
|
|
|||
Loading…
Reference in a new issue