mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Unsafe Sparse Compressed tensor factory function
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75961 Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent
7afe4afd86
commit
ff10e45993
6 changed files with 135 additions and 55 deletions
|
|
@ -5400,6 +5400,8 @@
|
|||
|
||||
- func: _sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
|
||||
- func: _sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
|
||||
- func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
|
||||
- func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
#else
|
||||
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
|
||||
#include <ATen/ops/_nnz_native.h>
|
||||
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
|
||||
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
|
||||
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
|
||||
#include <ATen/ops/_validate_sparse_compressed_tensor_args_native.h>
|
||||
|
|
@ -210,26 +211,31 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind
|
|||
|
||||
}
|
||||
|
||||
void _validate_sparse_compressed_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size, Layout layout) {
|
||||
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, layout);
|
||||
void _validate_sparse_compressed_tensor_args(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor& values, IntArrayRef size, Layout layout) {
|
||||
_validate_sparse_compressed_tensor_args_worker(compressed_indices, plain_indices, values, size, layout);
|
||||
}
|
||||
|
||||
void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
|
||||
_validate_sparse_compressed_tensor_args_worker(crow_indices, col_indices, values, size, kSparseCsr);
|
||||
}
|
||||
|
||||
// Construction of CSR tensors.
|
||||
SparseCsrTensor new_csr_tensor(const TensorOptions& options) {
|
||||
// Construction of CSR, CSC, BSR, and BSC tensors.
|
||||
|
||||
// Note: The usage of "Csr" in names like SparseCsrTensor,
|
||||
// SparseCsrCPU, SparseCsrCUDA, and SparseCsrTensorImpl exists because
|
||||
// of historical reasons (that ought to be removed in future) and does
|
||||
// not mean that the corresponding functionality would be CSR layout
|
||||
// only specific.
|
||||
SparseCsrTensor new_compressed_tensor(const TensorOptions& options) {
|
||||
// TODO: remove this comment after enabling autograd support for CSR tensor
|
||||
// constructor.
|
||||
// TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch());
|
||||
Layout layout = options.layout();
|
||||
TORCH_INTERNAL_ASSERT(layout == kSparseCsr);
|
||||
Layout layout = AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(options.layout(), "new_compressed_tensor", [&] { return the_layout; });
|
||||
DispatchKey dispatch_key;
|
||||
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
options.device().type() == kCPU || options.device().type() == kCUDA,
|
||||
"Could not run '", "sparse_csr_tensor", "' from the '", options.device(), "' device.)");
|
||||
"Could not run 'new_compressed_tensor' from the '", options.device(), "' device.)");
|
||||
|
||||
if (options.device().is_cuda()) {
|
||||
dispatch_key = DispatchKey::SparseCsrCUDA;
|
||||
|
|
@ -241,21 +247,57 @@ SparseCsrTensor new_csr_tensor(const TensorOptions& options) {
|
|||
DispatchKeySet(dispatch_key), layout, options.dtype());
|
||||
}
|
||||
|
||||
Tensor _sparse_csr_tensor_unsafe(const Tensor& crow_indices, const Tensor& col_indices,
|
||||
const Tensor& values,
|
||||
IntArrayRef size,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
|
||||
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
|
||||
|
||||
SparseCsrTensor self = new_csr_tensor(options);
|
||||
get_sparse_csr_impl(self)->set_member_tensors(crow_indices, col_indices, values, size);
|
||||
Tensor _sparse_compressed_tensor_unsafe(const Tensor& compressed_indices,
|
||||
const Tensor& plain_indices,
|
||||
const Tensor& values,
|
||||
IntArrayRef size,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
if (!layout) {
|
||||
AT_ERROR("sparse_compressed_tensor_unsafe expected sparse compressed tensor layout but got none");
|
||||
}
|
||||
Layout layout_ = layout.value();
|
||||
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{});
|
||||
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
|
||||
SparseCsrTensor self = new_compressed_tensor(options);
|
||||
get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
|
||||
return self;
|
||||
}
|
||||
|
||||
template <Layout required_layout>
|
||||
Tensor _sparse_compressed_tensor_unsafe_template(const Tensor& compressed_indices,
|
||||
const Tensor& plain_indices,
|
||||
const Tensor& values,
|
||||
IntArrayRef size,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
Layout layout_ = layout.value_or(required_layout);
|
||||
TORCH_CHECK(layout_ == required_layout, "sparse compressed layout must be ",required_layout, " but got ", layout_);
|
||||
TensorOptions options = TensorOptions().dtype(dtype).layout(layout_).device(device).pinned_memory(pin_memory);
|
||||
SparseCsrTensor self = new_compressed_tensor(options);
|
||||
get_sparse_csr_impl(self)->set_member_tensors(compressed_indices, plain_indices, values, size);
|
||||
return self;
|
||||
}
|
||||
|
||||
#define SPARSE_COMPRESSED_TENSOR_UNSAFE(KIND, REQUIRED_LAYOUT) \
|
||||
Tensor _sparse_##KIND##_tensor_unsafe(const Tensor& compressed_indices, \
|
||||
const Tensor& plain_indices, \
|
||||
const Tensor& values, \
|
||||
IntArrayRef size, \
|
||||
c10::optional<ScalarType> dtype, \
|
||||
c10::optional<Layout> layout, \
|
||||
c10::optional<Device> device, \
|
||||
c10::optional<bool> pin_memory) { \
|
||||
return _sparse_compressed_tensor_unsafe_template<REQUIRED_LAYOUT>(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); \
|
||||
}
|
||||
|
||||
SPARSE_COMPRESSED_TENSOR_UNSAFE(csr, kSparseCsr);
|
||||
|
||||
// TODO: This constructor should probably use an ATen abstract method in order
|
||||
// to make autograd dispatch available for the CSR constructor. See the relevant
|
||||
// note in native_functions.yaml.
|
||||
|
|
|
|||
|
|
@ -93,8 +93,7 @@ _SKIP_PYTHON_BINDINGS = [
|
|||
".*_forward_out",
|
||||
"_unsafe_view",
|
||||
"tensor",
|
||||
"_?sparse_coo_tensor.*",
|
||||
"_?sparse_csr_tensor.*",
|
||||
"_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*",
|
||||
"_arange.*",
|
||||
"_range.*",
|
||||
"linspace.*",
|
||||
|
|
|
|||
|
|
@ -419,7 +419,7 @@ static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args,
|
|||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
if (r.has_torch_function()) {
|
||||
return handle_torch_function(
|
||||
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
|
||||
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
|
||||
}
|
||||
jit::tracer::warn("torch.sparse_csr_tensor", jit::tracer::WARN_CONSTRUCTOR);
|
||||
return THPVariable_Wrap(torch::utils::sparse_csr_tensor_ctor(
|
||||
|
|
@ -429,27 +429,26 @@ static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args,
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * THPVariable__sparse_csr_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({
|
||||
"_sparse_csr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
|
||||
});
|
||||
|
||||
ParsedArgs<7> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
if (r.has_torch_function()) {
|
||||
return handle_torch_function(
|
||||
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
|
||||
}
|
||||
jit::tracer::warn("torch._sparse_csr_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR);
|
||||
return THPVariable_Wrap(torch::utils::_sparse_csr_tensor_unsafe_ctor(
|
||||
torch::tensors::get_default_dispatch_key(),
|
||||
torch::tensors::get_default_scalar_type(),
|
||||
r));
|
||||
END_HANDLE_TH_ERRORS
|
||||
#define THPVARIABLE_SPARSE_COMPRESSED_CTOR(NAME, NARGS, SIGNATURES) \
|
||||
static PyObject * THPVariable_ ## NAME(PyObject* self, PyObject* args, PyObject* kwargs) \
|
||||
{ \
|
||||
HANDLE_TH_ERRORS \
|
||||
static PythonArgParser parser SIGNATURES ; \
|
||||
ParsedArgs<NARGS> parsed_args; \
|
||||
auto r = parser.parse(args, kwargs, parsed_args); \
|
||||
if (r.has_torch_function()) { \
|
||||
return handle_torch_function(r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); \
|
||||
} \
|
||||
jit::tracer::warn("torch." # NAME, jit::tracer::WARN_CONSTRUCTOR); \
|
||||
return THPVariable_Wrap(torch::utils::NAME ## _ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), r)); \
|
||||
END_HANDLE_TH_ERRORS \
|
||||
}
|
||||
|
||||
THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_compressed_tensor_unsafe, 8,
|
||||
({"_sparse_compressed_tensor_unsafe(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool requires_grad=False)"}))
|
||||
THPVARIABLE_SPARSE_COMPRESSED_CTOR(_sparse_csr_tensor_unsafe, 7,
|
||||
({"_sparse_csr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)"}))
|
||||
|
||||
static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
|
|
@ -796,6 +795,7 @@ static PyMethodDef torch_functions_manual[] = {
|
|||
{"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
{"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
{"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
{"_sparse_compressed_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_compressed_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
{"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
{"_sparse_csr_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_csr_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
{"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
#include <ATen/dlpack.h>
|
||||
#include <ATen/InitialTensorOptions.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <ATen/SparseCsrTensorUtils.h>
|
||||
#include <ATen/TracerMode.h>
|
||||
#include <c10/core/Backend.h>
|
||||
#include <c10/core/DispatchKeySet.h>
|
||||
|
|
@ -666,21 +667,19 @@ Tensor sparse_csr_tensor_ctor(
|
|||
throw std::runtime_error("sparse_csr_tensor(): invalid arguments");
|
||||
}
|
||||
|
||||
Tensor _sparse_csr_tensor_unsafe_ctor(
|
||||
c10::DispatchKey dispatch_key,
|
||||
at::ScalarType scalar_type,
|
||||
PythonArgs& r) {
|
||||
Tensor _sparse_compressed_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) {
|
||||
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
|
||||
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
|
||||
enum {
|
||||
ARG_CROW_INDICES = 0,
|
||||
ARG_COL_INDICES,
|
||||
ARG_VALUES,
|
||||
ARG_SIZE,
|
||||
ARG_TYPE,
|
||||
ARG_DEVICE,
|
||||
ARG_REQUIRES_GRAD,
|
||||
ARGS_COUNT
|
||||
ARG_COMPRESSED_INDICES = 0,
|
||||
ARG_PLAIN_INDICES,
|
||||
ARG_VALUES,
|
||||
ARG_SIZE,
|
||||
ARG_TYPE,
|
||||
ARG_LAYOUT,
|
||||
ARG_DEVICE,
|
||||
ARG_REQUIRES_GRAD,
|
||||
ARGS_COUNT
|
||||
};
|
||||
bool type_inference = r.isNone(ARG_TYPE);
|
||||
const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key);
|
||||
|
|
@ -690,15 +689,52 @@ Tensor _sparse_csr_tensor_unsafe_ctor(
|
|||
/*copy_variables=*/false, /*copy_numpy=*/true,
|
||||
/*type_inference=*/type_inference);
|
||||
|
||||
Tensor crow_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_CROW_INDICES),
|
||||
Tensor compressed_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COMPRESSED_INDICES),
|
||||
/*copy_variables=*/false, /*copy_numpy=*/true,
|
||||
/*type_inference=*/true);
|
||||
|
||||
Tensor col_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COL_INDICES),
|
||||
Tensor plain_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_PLAIN_INDICES),
|
||||
/*copy_variables=*/false, /*copy_numpy=*/true,
|
||||
/*type_inference=*/true);
|
||||
return at::_sparse_compressed_tensor_unsafe(compressed_indices, plain_indices, values, r.intlist(ARG_SIZE),
|
||||
values.options().layout(r.layoutOptional(ARG_LAYOUT))).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
|
||||
}
|
||||
|
||||
template <c10::Layout required_layout>
|
||||
Tensor _sparse_compressed_tensor_unsafe_ctor_template(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) {
|
||||
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
|
||||
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
|
||||
enum {
|
||||
ARG_COMPRESSED_INDICES = 0,
|
||||
ARG_PLAIN_INDICES,
|
||||
ARG_VALUES,
|
||||
ARG_SIZE,
|
||||
ARG_TYPE,
|
||||
ARG_DEVICE,
|
||||
ARG_REQUIRES_GRAD,
|
||||
ARGS_COUNT
|
||||
};
|
||||
bool type_inference = r.isNone(ARG_TYPE);
|
||||
const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key);
|
||||
const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type);
|
||||
at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE));
|
||||
Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_VALUES),
|
||||
/*copy_variables=*/false, /*copy_numpy=*/true,
|
||||
/*type_inference=*/type_inference);
|
||||
|
||||
Tensor compressed_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COMPRESSED_INDICES),
|
||||
/*copy_variables=*/false, /*copy_numpy=*/true,
|
||||
/*type_inference=*/true);
|
||||
|
||||
return at::_sparse_csr_tensor_unsafe(crow_indices, col_indices, values, r.intlist(ARG_SIZE), values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
|
||||
Tensor plain_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_PLAIN_INDICES),
|
||||
/*copy_variables=*/false, /*copy_numpy=*/true,
|
||||
/*type_inference=*/true);
|
||||
return at::_sparse_compressed_tensor_unsafe(compressed_indices, plain_indices, values, r.intlist(ARG_SIZE),
|
||||
values.options().layout(required_layout)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
|
||||
}
|
||||
|
||||
Tensor _sparse_csr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r) {
|
||||
return _sparse_compressed_tensor_unsafe_ctor_template<c10::kSparseCsr>(dispatch_key, scalar_type, r);
|
||||
}
|
||||
|
||||
// Note [Ensuring sparse values and indices match devices]
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ at::Tensor _sparse_csr_tensor_unsafe_ctor(
|
|||
c10::DispatchKey dispatch_key,
|
||||
at::ScalarType scalar_type,
|
||||
PythonArgs& r);
|
||||
at::Tensor _sparse_compressed_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PythonArgs& r);
|
||||
void _validate_sparse_csr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
void _validate_sparse_compressed_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
|
||||
at::Tensor tensor_ctor(
|
||||
|
|
|
|||
Loading…
Reference in a new issue