Unsafe Sparse Compressed tensor factory function

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75961

Approved by: https://github.com/cpuhrsch
This commit is contained in:
Pearu Peterson 2022-04-28 22:51:10 +03:00 committed by PyTorch MergeBot
parent 7afe4afd86
commit ff10e45993
6 changed files with 135 additions and 55 deletions

View file

@ -5400,6 +5400,8 @@
- func: _sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: _sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
- func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

View file

@ -17,6 +17,7 @@
#else
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/_nnz_native.h>
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
@ -210,26 +211,31 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
}
void _validate_sparse_compressed_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size, Layout layout) {
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, layout);
void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout) {
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout);
}
void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr);
}
// Construction of CSR tensors.
SparseCsrTensor new_csr_tensor(const TensorOptions& options) {
// Construction of CSR, CSC, BSR, and BSC tensors.
// Note: The usage of "Csr" in names like SparseCsrTensor,
// SparseCsrCPU, SparseCsrCUDA, and SparseCsrTensorImpl exists because
// of historical reasons (that ought to be removed in future) and does
// not mean that the corresponding functionality would be CSR layout
// only specific.
SparseCsrTensor new_compressed_tensor(const TensorOptions& options) {
// TODO: remove this comment after enabling autograd support for CSR tensor
// constructor.
// TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch());
Layout layout = options.layout();
TORCH_INTERNAL_ASSERT(layout == kSparseCsr);
Layout layout = AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(options.layout(), "new_compressed_tensor", [&] { return the_layout; });
DispatchKey dispatch_key;
TORCH_CHECK_NOT_IMPLEMENTED(
options.device().type() == kCPU || options.device().type() == kCUDA,
"Could not run '", "sparse_csr_tensor", "' from the '", options.device(), "' device.)");
"Could not run 'new_compressed_tensor' from the '", options.device(), "' device.)");
if (options.device().is_cuda()) {
dispatch_key = DispatchKey::SparseCsrCUDA;
@ -241,21 +247,57 @@ SparseCsrTensor new_csr_tensor(const TensorOptions& options) {
DispatchKeySet(dispatch_key), layout, options.dtype());
}
Tensor _sparse_csr_tensor_unsafe(const Tensor& crow_indices, const Tensor& col_indices,
const Tensor& values,
IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
SparseCsrTensor self = new_csr_tensor(options);
get_sparse_csr_impl(self)->set_member_tensors(crow_indices, col_indices, values, size);
Tensor _sparse_compressed_tensor_unsafe(const Tensor& compressed_indices,
const Tensor& plain_indices,
const Tensor& values,
IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
if (!layout) {
AT_ERROR("sparse_compressed_tensor_unsafe expected sparse compressed tensor layout but got none");
}
Layout layout_ = layout.value();
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{});
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
SparseCsrTensor self = new_compressed_tensor(options);
get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
return self;
}
template <Layout required_layout>
Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indices,
const Tensor& plain_indices,
const Tensor& values,
IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
Layout layout_ = layout.value_or(required_layout);
TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ",required_layout, " but got ", layout_);
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
SparseCsrTensor self = new_compressed_tensor(options);
get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
return self;
}
#define SPARSE_COMPRESSED_TENSOR_UNSAFE(KIND, REQUIRED_LAYOUT) \
Tensor _sparse_##KIND##_tensor_unsafe(const Tensor& compressed_indices, \
const Tensor& plain_indices, \
const Tensor& values, \
IntArrayRef size, \
c10::optional<ScalarType> dtype, \
c10::optional<Layout> layout, \
c10::optional<Device> device, \
c10::optional<bool> pin_memory) { \
return _sparse_compressed_tensor_unsafe_template<REQUIRED_LAYOUT>(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); \
}
SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr);
// TODO: This constructor should probably use an ATen abstract method in order
// to make autograd dispatch available for the CSR constructor. See the relevant
// note in native_functions.yaml.

View file

@ -93,8 +93,7 @@ _SKIP_PYTHON_BINDINGS = [
".*_forward_out",
"_unsafe_view",
"tensor",
"_?sparse_coo_tensor.*",
"_?sparse_csr_tensor.*",
"_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*",
"_arange.*",
"_range.*",
"linspace.*",

View file

@ -419,7 +419,7 @@ static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args,
auto r = parser.parse(args, kwargs, parsed_args);
if (r.has_torch_function()) {
return handle_torch_function(
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
}
jit::tracer::warn("torch.sparse_csr_tensor", jit::tracer::WARN_CONSTRUCTOR);
return THPVariable_Wrap(torch::utils::sparse_csr_tensor_ctor(
@ -429,27 +429,26 @@ static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args,
END_HANDLE_TH_ERRORS
}
static PyObject * THPVariable__sparse_csr_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
"_sparse_csr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
});
ParsedArgs<7> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.has_torch_function()) {
return handle_torch_function(
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
}
jit::tracer::warn("torch._sparse_csr_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR);
return THPVariable_Wrap(torch::utils::_sparse_csr_tensor_unsafe_ctor(
torch::tensors::get_default_dispatch_key(),
torch::tensors::get_default_scalar_type(),
r));
END_HANDLE_TH_ERRORS
#define THPVARIABLE_SPARSE_COMPRESSED_CTOR(NAME, NARGS, SIGNATURES) \
static PyObject * THPVariable_ ## NAME(PyObject* self, PyObject* args, PyObject* kwargs) \
{ \
HANDLE_TH_ERRORS \
static PythonArgParser parser SIGNATURES ; \
ParsedArgs<NARGS> parsed_args; \
auto r = parser.parse(args, kwargs, parsed_args); \
if (r.has_torch_function()) { \
return handle_torch_function(r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); \
} \
jit::tracer::warn("torch." # NAME, jit::tracer::WARN_CONSTRUCTOR); \
return THPVariable_Wrap(torch::utils::NAME ## _ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), r)); \
END_HANDLE_TH_ERRORS \
}
THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_compressed_tensor_unsafe, 8,
({"_sparse_compressed_tensor_unsafe(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool requires_grad=False)"}))
THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_csr_tensor_unsafe, 7,
({"_sparse_csr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)"}))
static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
@ -796,6 +795,7 @@ static PyMethodDef torch_functions_manual[] = {
{"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
{"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
{"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
{"_sparse_compressed_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_compressed_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
{"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
{"_sparse_csr_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_csr_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
{"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},

View file

@ -20,6 +20,7 @@
#include <ATen/dlpack.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/TracerMode.h>
#include <c10/core/Backend.h>
#include <c10/core/DispatchKeySet.h>
@ -666,21 +667,19 @@ Tensor sparse_csr_tensor_ctor(
throw std::runtime_error("sparse_csr_tensor(): invalid arguments");
}
Tensor _sparse_csr_tensor_unsafe_ctor(
c10::DispatchKey dispatch_key,
at::ScalarType scalar_type,
PythonArgs& r) {
Tensor _sparse_compressed_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) {
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
enum {
ARG_CROW_INDICES = 0,
ARG_COL_INDICES,
ARG_VALUES,
ARG_SIZE,
ARG_TYPE,
ARG_DEVICE,
ARG_REQUIRES_GRAD,
ARGS_COUNT
ARG_COMPRESSED_INDICES = 0,
ARG_PLAIN_INDICES,
ARG_VALUES,
ARG_SIZE,
ARG_TYPE,
ARG_LAYOUT,
ARG_DEVICE,
ARG_REQUIRES_GRAD,
ARGS_COUNT
};
bool type_inference = r.isNone(ARG_TYPE);
const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key);
@ -690,15 +689,52 @@ Tensor _sparse_csr_tensor_unsafe_ctor(
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/type_inference);
Tensor crow_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_CROW_INDICES),
Tensor compressed_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COMPRESSED_INDICES),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/true);
Tensor col_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COL_INDICES),
Tensor plain_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_PLAIN_INDICES),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/true);
return at::_sparse_compressed_tensor_unsafe(compressed_indices, plain_indices, values, r.intlist(ARG_SIZE),
values.options().layout(r.layoutOptional(ARG_LAYOUT))).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
}
template <c10::Layout required_layout>
Tensor _sparse_compressed_tensor_unsafe_ctor_template(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) {
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
enum {
ARG_COMPRESSED_INDICES = 0,
ARG_PLAIN_INDICES,
ARG_VALUES,
ARG_SIZE,
ARG_TYPE,
ARG_DEVICE,
ARG_REQUIRES_GRAD,
ARGS_COUNT
};
bool type_inference = r.isNone(ARG_TYPE);
const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key);
const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE));
Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_VALUES),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/type_inference);
Tensor compressed_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COMPRESSED_INDICES),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/true);
return at::_sparse_csr_tensor_unsafe(crow_indices, col_indices, values, r.intlist(ARG_SIZE), values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
Tensor plain_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_PLAIN_INDICES),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/true);
return at::_sparse_compressed_tensor_unsafe(compressed_indices, plain_indices, values, r.intlist(ARG_SIZE),
values.options().layout(required_layout)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
}
Tensor _sparse_csr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) {
return _sparse_compressed_tensor_unsafe_ctor_template<c10::kSparseCsr>(dispatch_key, scalar_type, r);
}
// Note [Ensuring sparse values and indices match devices]

View file

@ -32,6 +32,7 @@ at::Tensor _sparse_csr_tensor_unsafe_ctor(
c10::DispatchKey dispatch_key,
at::ScalarType scalar_type,
PythonArgs& r);
at::Tensor _sparse_compressed_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r);
void _validate_sparse_csr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
void _validate_sparse_compressed_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor tensor_ctor(