[clone][sparse] Add torch._C._sparse namespace (#68672)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68672

This PR adds `python_module: sparse` to `native_function.yaml`.
These functions would appear in `torch._C._sparse` namespace instead of
just `torch`.

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D32517813

fbshipit-source-id: 7c3d6df57a24d7c7354d0fefe1b628dc89be9431
This commit is contained in:
Christian Puhrsch 2021-11-19 19:45:55 -08:00 committed by Facebook GitHub Bot
parent 95f4cd0ba9
commit 75955e4ef8
14 changed files with 151 additions and 51 deletions

View file

@ -236,6 +236,7 @@ libtorch_python_generated_sources = [
"torch/csrc/autograd/generated/python_nn_functions.cpp",
"torch/csrc/autograd/generated/python_fft_functions.cpp",
"torch/csrc/autograd/generated/python_linalg_functions.cpp",
"torch/csrc/autograd/generated/python_sparse_functions.cpp",
"torch/csrc/autograd/generated/python_special_functions.cpp",
]

View file

@ -542,6 +542,7 @@ The generated bindings are either exposed as methods on python_variable or funct
the torch._C._nn (marked with `python_module: nn`),
torch._C._fft (marked with `python_module: fft`),
torch._C._linalg (marked with `python_module: linalg`) objects,
torch._C._sparse (marked with `python_module: sparse`),
or torch._C._special (marked with `python_module: special`) objects.
### Undefined tensor conventions

View file

@ -4763,12 +4763,15 @@
SparseCUDA: _sparse_sum_backward_cuda
- func: _sparse_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
python_module: sparse
variants: function
- func: _sparse_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
python_module: sparse
variants: function
- func: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
python_module: sparse
dispatch:
SparseCPU: softmax_sparse_cpu
SparseCUDA: softmax_sparse_cuda
@ -4779,12 +4782,15 @@
SparseCUDA: softmax_backward_sparse_cuda
- func: _sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
python_module: sparse
variants: function
- func: _sparse_log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
python_module: sparse
variants: function
- func: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
python_module: sparse
dispatch:
SparseCPU: log_softmax_sparse_cpu
SparseCUDA: log_softmax_sparse_cuda
@ -4994,6 +5000,7 @@
# Functionally the same as addmm, but we give it a different derivative formula
# that doesn't propagate gradients to non-present entries on sparse.
- func: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor
python_module: sparse
dispatch:
CompositeExplicitAutograd: _sparse_addmm

View file

@ -405,6 +405,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_nn_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_fft_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_linalg_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_sparse_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_special_functions.cpp"
)
@ -449,6 +450,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TOOLS_PATH}/autograd/templates/python_nn_functions.cpp"
"${TOOLS_PATH}/autograd/templates/python_fft_functions.cpp"
"${TOOLS_PATH}/autograd/templates/python_linalg_functions.cpp"
"${TOOLS_PATH}/autograd/templates/python_sparse_functions.cpp"
"${TOOLS_PATH}/autograd/templates/python_special_functions.cpp"
"${TOOLS_PATH}/autograd/templates/variable_factories.h"
"${TOOLS_PATH}/autograd/templates/annotated_fn_args.py.in"

View file

@ -1964,7 +1964,7 @@ graph(%Ra, %Rb):
return torch.sparse.addmm(input, input1, input2)
def test_sparse_addmm_alpha_beta(input, input1, input2):
return torch.sparse.addmm(input, input1, input2, 1.3, 1.5)
return torch.sparse.addmm(input, input1, input2, alpha=1.3, beta=1.5)
self.checkScript(test_sparse_addmm, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))
self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4)))

View file

@ -1,7 +1,7 @@
# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn. torch._C._fft, torch._C._linalg or torch._C._special objects.
# torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._sparse or torch._C._special objects.
#
# Code tries to stick to the following rules:
@ -148,6 +148,9 @@ def is_py_fft_function(f: NativeFunction) -> bool:
def is_py_linalg_function(f: NativeFunction) -> bool:
    """Return whether `f` is marked `python_module: linalg` and thus
    bound on the torch._C._linalg module."""
    declared_module = f.python_module
    return declared_module == 'linalg'
def is_py_sparse_function(f: NativeFunction) -> bool:
    """Return whether `f` is marked `python_module: sparse` and thus
    bound on the torch._C._sparse module."""
    declared_module = f.python_module
    return declared_module == 'sparse'
def is_py_special_function(f: NativeFunction) -> bool:
    """Return whether `f` is marked `python_module: special` and thus
    bound on the torch._C._special module."""
    declared_module = f.python_module
    return declared_module == 'special'
@ -182,6 +185,9 @@ def gen(out: str, native_yaml_path: str, deprecated_yaml_path: str, template_pat
create_python_bindings(
fm, functions, is_py_linalg_function, 'torch.linalg', 'python_linalg_functions.cpp', method=False)
create_python_bindings(
fm, functions, is_py_sparse_function, 'torch.sparse', 'python_sparse_functions.cpp', method=False)
create_python_bindings(
fm, functions, is_py_special_function, 'torch.special', 'python_special_functions.cpp', method=False)
@ -599,6 +605,7 @@ if(check_has_torch_function(self_)) {{
"torch.nn": "THPNNVariableFunctionsModule",
"torch.fft": "THPFFTVariableFunctionsModule",
"torch.linalg": "THPLinalgVariableFunctionsModule",
"torch.sparse": "THPSparseVariableFunctionsModule",
"torch.special": "THPSpecialVariableFunctionsModule",
}[module] if module else "THPVariableClass"

View file

@ -0,0 +1,60 @@
// ${generated_comment}
//
// Codegen template for python_sparse_functions.cpp: Python bindings for
// native functions marked `python_module: sparse`, exposed to Python as
// the torch._C._sparse module.
#include "torch/csrc/Device.h"
#include "torch/csrc/DynamicTypes.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/autograd/python_sparse_functions.h"
#include "torch/csrc/autograd/python_variable.h"
#include "torch/csrc/autograd/utils/wrap_outputs.h"
#include "torch/csrc/autograd/utils/python_arg_parsing.h"
#include "torch/csrc/utils/pycfunction_helpers.h"
#include "torch/csrc/utils/python_arg_parser.h"
#include "torch/csrc/utils/structseq.h"
using at::Tensor;
using at::Scalar;
using at::ScalarType;
using at::MemoryFormat;
using at::Generator;
using at::IntArrayRef;
using at::TensorList;
using namespace torch::autograd::utils;
namespace torch { namespace autograd {
// generated forward declarations start here
${py_forwards}
// Method table for torch._C._sparse; entries are generated from
// native_functions.yaml by the autograd codegen.
static PyMethodDef sparse_functions[] = {
${py_method_defs}
{NULL}
};
// Module object; assigned once in initSparseFunctions and kept for the
// lifetime of the process for use by the generated bindings.
static PyObject* THPSparseVariableFunctionsModule = NULL;
// Creates the torch._C._sparse extension module and registers it as the
// `_sparse` attribute of `module` (torch._C). Throws python_error on failure.
void initSparseFunctions(PyObject* module) {
static struct PyModuleDef def = {
PyModuleDef_HEAD_INIT,
"torch._C._sparse",
NULL,  // m_doc: no module docstring
-1,  // m_size: -1 means the module keeps state in global variables
sparse_functions
};
PyObject* sparse = PyModule_Create(&def);
THPSparseVariableFunctionsModule = sparse;
if (!sparse) {
throw python_error();
}
// steals a reference to sparse
if (PyModule_AddObject(module, "_sparse", sparse) != 0) {
throw python_error();
}
}
// generated methods start here
${py_methods}
}} // namespace torch::autograd

View file

@ -22,6 +22,7 @@ GENERATED_CPP = [
"autograd/generated/python_nn_functions.cpp",
"autograd/generated/python_fft_functions.cpp",
"autograd/generated/python_linalg_functions.cpp",
"autograd/generated/python_sparse_functions.cpp",
"autograd/generated/python_special_functions.cpp",
"autograd/generated/python_torch_functions_0.cpp",
"autograd/generated/python_torch_functions_1.cpp",
@ -864,6 +865,7 @@ def glob_libtorch_python_sources(gencode_pattern = ":generate-code[{}]"):
"autograd/generated/python_nn_functions.cpp",
"autograd/generated/python_fft_functions.cpp",
"autograd/generated/python_linalg_functions.cpp",
"autograd/generated/python_sparse_functions.cpp",
"autograd/generated/python_special_functions.cpp",
"autograd/generated/python_torch_functions_0.cpp",
"autograd/generated/python_torch_functions_1.cpp",

View file

@ -37,6 +37,7 @@
#include <torch/csrc/autograd/python_nn_functions.h>
#include <torch/csrc/autograd/python_fft_functions.h>
#include <torch/csrc/autograd/python_linalg_functions.h>
#include <torch/csrc/autograd/python_sparse_functions.h>
#include <torch/csrc/autograd/python_special_functions.h>
#include <torch/csrc/autograd/python_legacy_variable.h>
#include <torch/csrc/autograd/python_variable.h>
@ -837,6 +838,7 @@ PyObject* initModule() {
torch::autograd::initNNFunctions(module);
torch::autograd::initFFTFunctions(module);
torch::autograd::initLinalgFunctions(module);
torch::autograd::initSparseFunctions(module);
torch::autograd::initSpecialFunctions(module);
torch::autograd::init_legacy_variable(module);
torch::python::init_bindings(module);

View file

@ -13,6 +13,7 @@
#include <torch/nn.h>
#include <torch/optim.h>
#include <torch/serialize.h>
#include <torch/sparse.h>
#include <torch/special.h>
#include <torch/types.h>
#include <torch/utils.h>

View file

@ -0,0 +1,8 @@
#pragma once
#include <ATen/ATen.h>
// Placeholder header for the torch::sparse C++ namespace (parallel to
// torch/special.h and friends); no declarations yet.
namespace torch {
namespace sparse {
}} // torch::sparse

View file

@ -0,0 +1,7 @@
#pragma once
namespace torch { namespace autograd {
// Creates the torch._C._sparse module from the generated bindings and
// attaches it as `_sparse` on `module` (torch._C). Defined in the
// generated python_sparse_functions.cpp.
void initSparseFunctions(PyObject* module);
}} // namespace torch::autograd

View file

@ -13,7 +13,7 @@ from typing import Dict, Optional
_builtin_table: Optional[Dict[int, str]] = None
_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._special) # type: ignore[attr-defined] # noqa: B950
_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._sparse, torch._C._special) # type: ignore[attr-defined] # noqa: B950
_builtin_ops = [
# Pairs of (function, op_name)

View file

@ -2,6 +2,7 @@
from typing import Optional, Tuple, List, Union
import torch
from torch._C import _add_docstr, _sparse # type: ignore[attr-defined]
from torch import Tensor
# A workaround to support both TorchScript and MyPy:
@ -24,22 +25,21 @@ __all__ = [
]
def addmm(mat: Tensor, mat1: Tensor, mat2: Tensor,
beta: float = 1., alpha: float = 1.) -> Tensor:
r"""
This function does exact same thing as :func:`torch.addmm` in the forward,
except that it supports backward for sparse matrix :attr:`mat1`. :attr:`mat1`
need to have `sparse_dim = 2`. Note that the gradients of :attr:`mat1` is a
coalesced sparse tensor.
addmm = _add_docstr(_sparse._sparse_addmm, r"""
sparse.addmm(mat, mat1, mat2, *, beta=1., alpha=1.) -> Tensor
Args:
mat (Tensor): a dense matrix to be added
mat1 (Tensor): a sparse matrix to be multiplied
mat2 (Tensor): a dense matrix to be multiplied
beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
"""
return torch._sparse_addmm(mat, mat1, mat2, beta=beta, alpha=alpha)
This function does the exact same thing as :func:`torch.addmm` in the forward,
except that it supports backward for sparse matrix :attr:`mat1`. :attr:`mat1`
needs to have `sparse_dim = 2`. Note that the gradient of :attr:`mat1` is a
coalesced sparse tensor.
Args:
mat (Tensor): a dense matrix to be added
mat1 (Tensor): a sparse matrix to be multiplied
mat2 (Tensor): a dense matrix to be multiplied
beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
""")
def mm(mat1: Tensor, mat2: Tensor) -> Tensor:
@ -159,45 +159,47 @@ def sum(input: Tensor, dim: DimOrDims = None,
return torch._sparse_sum(input, dtype=dtype)
def softmax(input: Tensor, dim: int, dtype: Optional[DType] = None) -> Tensor:
r"""Applies a softmax function.
softmax = _add_docstr(_sparse._sparse_softmax, r"""
sparse.softmax(input, dim, *, dtype=None) -> Tensor
Softmax is defined as:
Applies a softmax function.
:math:`\text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}`
Softmax is defined as:
where :math:`i, j` run over sparse tensor indices and unspecified
entries are ignores. This is equivalent to defining unspecified
entries as negative infinity so that :math:`exp(x_k) = 0` when the
entry with index :math:`k` has not specified.
:math:`\text{Softmax}(x_{i}) = \frac{exp(x_i)}{\sum_j exp(x_j)}`
It is applied to all slices along `dim`, and will re-scale them so
that the elements lie in the range `[0, 1]` and sum to 1.
where :math:`i, j` run over sparse tensor indices and unspecified
entries are ignored. This is equivalent to defining unspecified
entries as negative infinity so that :math:`exp(x_k) = 0` when the
entry with index :math:`k` is not specified.
Args:
input (Tensor): input
dim (int): A dimension along which softmax will be computed.
dtype (:class:`torch.dtype`, optional): the desired data type
of returned tensor. If specified, the input tensor is
casted to :attr:`dtype` before the operation is
performed. This is useful for preventing data type
overflows. Default: None
"""
return torch._sparse_softmax(input, dim, dtype=dtype)
It is applied to all slices along `dim`, and will re-scale them so
that the elements lie in the range `[0, 1]` and sum to 1.
Args:
input (Tensor): input
dim (int): A dimension along which softmax will be computed.
dtype (:class:`torch.dtype`, optional): the desired data type
of returned tensor. If specified, the input tensor is
casted to :attr:`dtype` before the operation is
performed. This is useful for preventing data type
overflows. Default: None
""")
def log_softmax(input: Tensor, dim: int, dtype: Optional[DType] = None) -> Tensor:
r"""Applies a softmax function followed by logarithm.
log_softmax = _add_docstr(_sparse._sparse_log_softmax, r"""
sparse.log_softmax(input, dim, *, dtype=None) -> Tensor
See :class:`~torch.sparse.softmax` for more details.
Applies a softmax function followed by logarithm.
Args:
input (Tensor): input
dim (int): A dimension along which softmax will be computed.
dtype (:class:`torch.dtype`, optional): the desired data type
of returned tensor. If specified, the input tensor is
casted to :attr:`dtype` before the operation is
performed. This is useful for preventing data type
overflows. Default: None
"""
return torch._sparse_log_softmax(input, dim, dtype=dtype)
See :class:`~torch.sparse.softmax` for more details.
Args:
input (Tensor): input
dim (int): A dimension along which softmax will be computed.
dtype (:class:`torch.dtype`, optional): the desired data type
of returned tensor. If specified, the input tensor is
casted to :attr:`dtype` before the operation is
performed. This is useful for preventing data type
overflows. Default: None
""")