From ff10e4599321fb6923b9522632a40fb414bbeeab Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Thu, 28 Apr 2022 22:51:10 +0300 Subject: [PATCH] Unsafe Sparse Compressed tensor factory function Pull Request resolved: https://github.com/pytorch/pytorch/pull/75961 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/native_functions.yaml | 2 + .../ATen/native/sparse/SparseCsrTensor.cpp | 78 ++++++++++++++----- tools/autograd/gen_python_functions.py | 3 +- .../python_torch_functions_manual.cpp | 40 +++++----- torch/csrc/utils/tensor_new.cpp | 66 ++++++++++++---- torch/csrc/utils/tensor_new.h | 1 + 6 files changed, 135 insertions(+), 55 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 2409122bdac..d130daf1f1a 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5400,6 +5400,8 @@ - func: _sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor +- func: _sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + - func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor - func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp index a7d66659d37..5802f460d09 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp @@ -17,6 +17,7 @@ #else #include #include +#include #include #include #include @@ -210,26 +211,31 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind } -void _validate_sparse_compressed_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size, Layout layout) { - _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, layout); +void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout) { + _validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout); } void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) { _validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr); } -// Construction of CSR tensors. -SparseCsrTensor new_csr_tensor(const TensorOptions& options) { +// Construction of CSR, CSC, BSR, and BSC tensors. + +// Note: The usage of "Csr" in names like SparseCsrTensor, +// SparseCsrCPU, SparseCsrCUDA, and SparseCsrTensorImpl exists because +// of historical reasons (that ought to be removed in future) and does +// not mean that the corresponding functionality would be CSR layout +// only specific. +SparseCsrTensor new_compressed_tensor(const TensorOptions& options) { // TODO: remove this comment after enabling autograd support for CSR tensor // constructor. // TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch()); - Layout layout = options.layout(); - TORCH_INTERNAL_ASSERT(layout == kSparseCsr); + Layout layout = AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(options.layout(), "new_compressed_tensor", [&] { return the_layout; }); DispatchKey dispatch_key; TORCH_CHECK_NOT_IMPLEMENTED( options.device().type() == kCPU || options.device().type() == kCUDA, - "Could not run '", "sparse_csr_tensor", "' from the '", options.device(), "' device.)"); + "Could not run 'new_compressed_tensor' from the '", options.device(), "' device.)"); if (options.device().is_cuda()) { dispatch_key = DispatchKey::SparseCsrCUDA; @@ -241,21 +247,57 @@ SparseCsrTensor new_csr_tensor(const TensorOptions& options) { DispatchKeySet(dispatch_key), layout, options.dtype()); } -Tensor _sparse_csr_tensor_unsafe(const Tensor& crow_indices, const Tensor& col_indices, - const Tensor& values, - IntArrayRef size, - c10::optional dtype, - c10::optional layout, - c10::optional device, - c10::optional pin_memory) { - TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); - - SparseCsrTensor self = new_csr_tensor(options); - get_sparse_csr_impl(self)->set_member_tensors(crow_indices, col_indices, values, size); +Tensor _sparse_compressed_tensor_unsafe(const Tensor& compressed_indices, + const Tensor& plain_indices, + const Tensor& values, + IntArrayRef size, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + if (!layout) { + AT_ERROR("sparse_compressed_tensor_unsafe expected sparse compressed tensor layout but got none"); + } + Layout layout_ = layout.value(); + AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{}); + TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory); + SparseCsrTensor self = new_compressed_tensor(options); + get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size); return self; } +template +Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indices, + const Tensor& plain_indices, + const Tensor& values, + IntArrayRef size, + c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { + Layout layout_ = layout.value_or(required_layout); + TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ",required_layout, " but got ", layout_); + TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory); + SparseCsrTensor self = new_compressed_tensor(options); + get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size); + return self; +} + +#define SPARSE_COMPRESSED_TENSOR_UNSAFE(KIND, REQUIRED_LAYOUT) \ + Tensor _sparse_##KIND##_tensor_unsafe(const Tensor& compressed_indices, \ + const Tensor& plain_indices, \ + const Tensor& values, \ + IntArrayRef size, \ + c10::optional dtype, \ + c10::optional layout, \ + c10::optional device, \ + c10::optional pin_memory) { \ + return _sparse_compressed_tensor_unsafe_template(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); \ + } + +SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr); + // TODO: This constructor should probably use an ATen abstract method in order // to make autograd dispatch available for the CSR constructor. See the relevant // note in native_functions.yaml. diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 121e2428f7d..da4c3633483 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -93,8 +93,7 @@ _SKIP_PYTHON_BINDINGS = [ ".*_forward_out", "_unsafe_view", "tensor", - "_?sparse_coo_tensor.*", - "_?sparse_csr_tensor.*", + "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*", "_arange.*", "_range.*", "linspace.*", diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index db128a9e807..4cb18f015fa 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -419,7 +419,7 @@ static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args, auto r = parser.parse(args, kwargs, parsed_args); if (r.has_torch_function()) { return handle_torch_function( - r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); + r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); } jit::tracer::warn("torch.sparse_csr_tensor", jit::tracer::WARN_CONSTRUCTOR); return THPVariable_Wrap(torch::utils::sparse_csr_tensor_ctor( @@ -429,27 +429,26 @@ static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args, END_HANDLE_TH_ERRORS } -static PyObject * THPVariable__sparse_csr_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs) -{ - HANDLE_TH_ERRORS - static PythonArgParser parser({ - "_sparse_csr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)", - }); - - ParsedArgs<7> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); - if (r.has_torch_function()) { - return handle_torch_function( - r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); - } - jit::tracer::warn("torch._sparse_csr_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR); - return THPVariable_Wrap(torch::utils::_sparse_csr_tensor_unsafe_ctor( - torch::tensors::get_default_dispatch_key(), - torch::tensors::get_default_scalar_type(), - r)); - END_HANDLE_TH_ERRORS +#define THPVARIABLE_SPARSE_COMPRESSED_CTOR(NAME, NARGS, SIGNATURES) \ +static PyObject * THPVariable_ ## NAME(PyObject* self, PyObject* args, PyObject* kwargs) \ +{ \ + HANDLE_TH_ERRORS \ + static PythonArgParser parser SIGNATURES ; \ + ParsedArgs parsed_args; \ + auto r = parser.parse(args, kwargs, parsed_args); \ + if (r.has_torch_function()) { \ + return handle_torch_function(r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); \ + } \ + jit::tracer::warn("torch." # NAME, jit::tracer::WARN_CONSTRUCTOR); \ + return THPVariable_Wrap(torch::utils::NAME ## _ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), r)); \ + END_HANDLE_TH_ERRORS \ } +THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_compressed_tensor_unsafe, 8, + ({"_sparse_compressed_tensor_unsafe(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool requires_grad=False)"})) +THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_csr_tensor_unsafe, 7, + ({"_sparse_csr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)"})) + static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS @@ -796,6 +795,7 @@ static PyMethodDef torch_functions_manual[] = { {"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, {"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, {"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"_sparse_compressed_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_compressed_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, {"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, {"_sparse_csr_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_csr_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, {"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 9202daeb727..01d78ee3559 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -666,21 +667,19 @@ Tensor sparse_csr_tensor_ctor( throw std::runtime_error("sparse_csr_tensor(): invalid arguments"); } -Tensor _sparse_csr_tensor_unsafe_ctor( - c10::DispatchKey dispatch_key, - at::ScalarType scalar_type, - PythonArgs& r) { +Tensor _sparse_compressed_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key))); TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key))); enum { - ARG_CROW_INDICES = 0, - ARG_COL_INDICES, - ARG_VALUES, - ARG_SIZE, - ARG_TYPE, - ARG_DEVICE, - ARG_REQUIRES_GRAD, - ARGS_COUNT + ARG_COMPRESSED_INDICES = 0, + ARG_PLAIN_INDICES, + ARG_VALUES, + ARG_SIZE, + ARG_TYPE, + ARG_LAYOUT, + ARG_DEVICE, + ARG_REQUIRES_GRAD, + ARGS_COUNT }; bool type_inference = r.isNone(ARG_TYPE); const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key); @@ -690,15 +689,52 @@ Tensor _sparse_csr_tensor_unsafe_ctor( /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/type_inference); - Tensor crow_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_CROW_INDICES), + Tensor compressed_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COMPRESSED_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true); - Tensor col_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COL_INDICES), + Tensor plain_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_PLAIN_INDICES), + /*copy_variables=*/false, /*copy_numpy=*/true, + /*type_inference=*/true); + return at::_sparse_compressed_tensor_unsafe(compressed_indices, plain_indices, values, r.intlist(ARG_SIZE), + values.options().layout(r.layoutOptional(ARG_LAYOUT))).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); +} + +template +Tensor _sparse_compressed_tensor_unsafe_ctor_template(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { + TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key))); + TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key))); + enum { + ARG_COMPRESSED_INDICES = 0, + ARG_PLAIN_INDICES, + ARG_VALUES, + ARG_SIZE, + ARG_TYPE, + ARG_DEVICE, + ARG_REQUIRES_GRAD, + ARGS_COUNT + }; + bool type_inference = r.isNone(ARG_TYPE); + const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key); + const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type); + at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE)); + Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_VALUES), + /*copy_variables=*/false, /*copy_numpy=*/true, + /*type_inference=*/type_inference); + + Tensor compressed_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COMPRESSED_INDICES), /*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true); - return at::_sparse_csr_tensor_unsafe(crow_indices, col_indices, values, r.intlist(ARG_SIZE), values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); + Tensor plain_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_PLAIN_INDICES), + /*copy_variables=*/false, /*copy_numpy=*/true, + /*type_inference=*/true); + return at::_sparse_compressed_tensor_unsafe(compressed_indices, plain_indices, values, r.intlist(ARG_SIZE), + values.options().layout(required_layout)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD)); +} + +Tensor _sparse_csr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) { + return _sparse_compressed_tensor_unsafe_ctor_template(dispatch_key, scalar_type, r); } // Note [Ensuring sparse values and indices match devices] diff --git a/torch/csrc/utils/tensor_new.h b/torch/csrc/utils/tensor_new.h index 7b19138c214..d3e9c32683c 100644 --- a/torch/csrc/utils/tensor_new.h +++ b/torch/csrc/utils/tensor_new.h @@ -32,6 +32,7 @@ at::Tensor _sparse_csr_tensor_unsafe_ctor( c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); +at::Tensor _sparse_compressed_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r); void _validate_sparse_csr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); void _validate_sparse_compressed_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); at::Tensor tensor_ctor(