From 4cf2c646c2d00a6f1ec6859ce28ce972366d343e Mon Sep 17 00:00:00 2001 From: Heitor Schueroff Date: Sun, 9 May 2021 04:49:35 -0700 Subject: [PATCH] Added torch.linalg.matrix_norm (#57127) Summary: This PR is focused on the API for `linalg.matrix_norm` and delegates computations to `linalg.norm` for the moment. The main difference between the norms is when `dim=None`. In this case: - `linalg.norm` will compute a vector norm on the flattened input if `ord=None`, otherwise it requires the input to be either 1D or 2D in order to disambiguate between vector and matrix norm - `linalg.vector_norm` will flatten the input - `linalg.matrix_norm` will compute the norm over the last two dimensions, treating the input as a batch of matrices In future PRs, the computations will be moved to `torch.linalg.matrix_norm`, and `torch.norm` and `torch.linalg.norm` will delegate computations to either `linalg.vector_norm` or `linalg.matrix_norm` based on the arguments provided. Pull Request resolved: https://github.com/pytorch/pytorch/pull/57127 Reviewed By: mrshenli Differential Revision: D28186736 Pulled By: mruberry fbshipit-source-id: 99ce2da9d1c4df3d9dd82c0a312c9570da5caf25 --- aten/src/ATen/core/interned_strings.h | 1 + aten/src/ATen/native/LinearAlgebra.cpp | 64 ++++ aten/src/ATen/native/native_functions.yaml | 12 + docs/source/linalg.rst | 1 + test/test_linalg.py | 60 +++- torch/csrc/api/include/torch/linalg.h | 33 ++ torch/linalg/__init__.py | 339 +++++++++++------- torch/overrides.py | 1 + .../_internal/common_methods_invocations.py | 23 +- 9 files changed, 388 insertions(+), 146 deletions(-) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index ad2a1c23ffd..af5fa9c9fad 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -208,6 +208,7 @@ namespace c10 { _(aten, linalg_multi_dot) \ _(aten, linalg_norm) \ _(aten, linalg_vector_norm) \ + _(aten, linalg_matrix_norm) \ _(aten, append) \ _(aten, 
item) \ _(aten, format) \ diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index b7725a20a9e..be839051bd4 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -2313,6 +2313,70 @@ Tensor& linalg_vector_norm_out(const Tensor& self, const Scalar& ord, optional dtype) { + TORCH_CHECK( + self.ndimension() >= 2, + "linalg.matrix_norm(): input tensor must be a matrix or batch of matrices"); + ScalarType in_dtype = dtype.value_or(self.scalar_type()); + TORCH_CHECK( + in_dtype == kFloat || in_dtype == kDouble || in_dtype == kComplexFloat || + in_dtype == kComplexDouble, + "linalg.matrix_norm(): only supports the float, double, cfloat and cdouble dtypes, but got: ", + toString(in_dtype)); + TORCH_CHECK( + dim.size() == 2, "linalg.matrix_norm(): dim must be a 2-tuple of ints"); +} + +} // namespace + +Tensor linalg_matrix_norm( + const Tensor& self, + const Scalar& ord, + IntArrayRef dim, + bool keepdim, + optional<ScalarType> dtype) { + check_linalg_matrix_norm_args(self, dim, dtype); + return at::native::linalg_norm(self, ord, dim, keepdim, dtype); +} + +Tensor& linalg_matrix_norm_out( + const Tensor& self, + const Scalar& ord, + IntArrayRef dim, + bool keepdim, + optional<ScalarType> dtype, + Tensor& result) { + check_linalg_matrix_norm_args(self, dim, dtype); + return at::native::linalg_norm_out(self, ord, dim, keepdim, dtype, result); +} + +Tensor linalg_matrix_norm( + const Tensor& self, + std::string ord, + IntArrayRef dim, + bool keepdim, + optional<ScalarType> dtype) { + check_linalg_matrix_norm_args(self, dim, dtype); + return at::native::linalg_norm(self, ord, dim, keepdim, dtype); +} + +Tensor& linalg_matrix_norm_out( + const Tensor& self, + std::string ord, + IntArrayRef dim, + bool keepdim, + optional<ScalarType> dtype, + Tensor& result) { + check_linalg_matrix_norm_args(self, dim, dtype); + return at::native::linalg_norm_out(self, ord, dim, keepdim, dtype, result); +} + // Numerical or None norms Tensor 
linalg_norm(const Tensor& self, const optional& opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : toValueType(self.scalar_type())).device(self.device()); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b7139ae1de5..af9880cd51d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -9636,6 +9636,18 @@ dispatch: CPU, CUDA: linalg_vector_norm_out +- func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + python_module: linalg + +- func: linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + +- func: linalg_matrix_norm.str_ord(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + python_module: linalg + +- func: linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + - func: linalg_svd.U(Tensor self, bool full_matrices=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) 
Vh) python_module: linalg diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 2472013dd93..757e8241b9f 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -22,6 +22,7 @@ Matrix Properties norm vector_norm + matrix_norm det slogdet cond diff --git a/test/test_linalg.py b/test/test_linalg.py index c99f7a908c6..9622c7e1bdf 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1499,24 +1499,29 @@ class TestLinalg(TestCase): for ord in ord_settings: run_test_case(input, ord, dim, keepdim) - # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that - # their matrix norm results match + # This test compares torch.linalg.norm, torch.linalg.matrix_norm and numpy.linalg.norm to + # ensure that their matrix norm results match. @skipMeta # https://github.com/pytorch/pytorch/issues/54082 @skipCUDAIfNoMagma @dtypes(torch.float, torch.double) @precisionOverride({torch.float32: 2e-5}) def test_norm_matrix(self, device, dtype): def run_test_case(input, ord, dim, keepdim): + msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' result = torch.linalg.norm(input, ord, dim, keepdim) input_numpy = input.cpu().numpy() result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) - msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' - self.assertEqual(result, result_numpy, msg=msg) + def check(op): + result = op(input, ord, dim, keepdim) + self.assertEqual(result, result_numpy, msg=msg) + result_out = torch.empty_like(result) + op(input, ord, dim, keepdim, out=result_out) + self.assertEqual(result, result_out, msg=msg) - result_out = torch.empty_like(result) - torch.linalg.norm(input, ord, dim, keepdim, out=result_out) - self.assertEqual(result, result_out, msg=msg) + check(torch.linalg.norm) + if ord is not None and dim is not None: + check(torch.linalg.matrix_norm) ord_matrix = [1, -1, 2, -2, inf, -inf, 'nuc', 'fro'] S = 10 @@ -1531,8 +1536,10 @@ 
class TestLinalg(TestCase): ((S, S, S, S), ord_matrix, (-3, 2)), ] L = 1_000 + if dtype == torch.double: test_cases.append(((L, L), ord_matrix, None)) + for keepdim in [True, False]: for input_size, ord_settings, dim in test_cases: input = torch.randn(*input_size, dtype=dtype, device=device) @@ -1765,6 +1772,29 @@ class TestLinalg(TestCase): result_n = np.linalg.norm(x_n, ord=ord) self.assertEqual(result, result_n, msg=msg) + @skipMeta # https://github.com/pytorch/pytorch/issues/54082 + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float, torch.double) + @precisionOverride({torch.float32: 2e-5}) + def test_matrix_norm(self, device, dtype): + # Test only inputs for which torch.linalg.matrix_norm diverges from torch.linalg.norm + A = make_tensor((2, 2, 2), device, dtype) + + with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm\(\):.*must be a matrix.*'): + torch.linalg.matrix_norm(make_tensor((2,), device, dtype)) + with self.assertRaisesRegex(RuntimeError, r'linalg.matrix_norm\(\):.*must be a 2-tuple.*'): + torch.linalg.matrix_norm(A, dim=(0,)) + with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'): + torch.linalg.matrix_norm(A, ord=0) + with self.assertRaisesRegex(RuntimeError, r'.*not supported.*'): + torch.linalg.matrix_norm(A, ord=3.0) + + # Test that the default dim is the last two dimensions, i.e. dim=(-2, -1) + ref = torch.linalg.norm(A, dim=(-2, -1)) + res = torch.linalg.matrix_norm(A) + self.assertEqual(ref, res) + # Test that linal.norm gives the same result as numpy when inputs # contain extreme values (inf, -inf, nan) @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") @@ -1864,15 +1894,22 @@ class TestLinalg(TestCase): def run_test_case(input, ord, dim, keepdim, should_error): msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' input_numpy = input.cpu().numpy() + ops = [torch.linalg.norm] + + if ord is not None and dim is not None: + ops.append(torch.linalg.matrix_norm) + if should_error: with self.assertRaises(ValueError): 
np.linalg.norm(input_numpy, ord, dim, keepdim) - with self.assertRaises(IndexError): - torch.linalg.norm(input, ord, dim, keepdim) + for op in ops: + with self.assertRaises(IndexError): + op(input, ord, dim, keepdim) else: result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) - result = torch.linalg.norm(input, ord, dim, keepdim) - self.assertEqual(result, result_numpy, msg=msg) + for op in ops: + result = op(input, ord, dim, keepdim) + self.assertEqual(result, result_numpy, msg=msg) ord_matrix = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf, None] S = 10 @@ -1886,6 +1923,7 @@ class TestLinalg(TestCase): ((0, 0, S), [1, 2, inf, -1, -2, -inf], (0, 1)), ((0, 0, S), [1, 2, inf, -1, -2, -inf], (1, 0)), ] + for keepdim in [True, False]: for input_size, error_ords, dim in test_cases: input = torch.randn(*input_size, dtype=dtype, device=device) diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index 7b1fa112855..7b7dc7a65df 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -96,6 +96,22 @@ inline Tensor& vector_norm_out(Tensor& result, const Tensor& self, Scalar ord, o return torch::linalg_vector_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } +inline Tensor matrix_norm(const Tensor& self, const Scalar& ord, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype) { + return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype); +} + +inline Tensor& matrix_norm_out(const Tensor& self, const Scalar& ord, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype, Tensor& result) { + return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype); +} + +inline Tensor matrix_norm(const Tensor& self, std::string ord, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype) { + return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype); +} + +inline Tensor& matrix_norm_out(const Tensor& self, std::string ord, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype, Tensor& result) { + return 
torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype); +} + inline Tensor matrix_power(const Tensor& self, int64_t n) { return torch::linalg_matrix_power(self, n); } @@ -323,6 +339,23 @@ inline Tensor& vector_norm_out(Tensor& result, const Tensor& self, Scalar ord, o return detail::vector_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); } +/// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_norm +inline Tensor matrix_norm(const Tensor& self, const Scalar& ord, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype) { + return detail::matrix_norm(self, ord, dim, keepdim, dtype); +} + +inline Tensor& matrix_norm_out(const Tensor& self, const Scalar& ord, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype, Tensor& result) { + return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result); +} + +inline Tensor matrix_norm(const Tensor& self, std::string ord, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype) { + return detail::matrix_norm(self, ord, dim, keepdim, dtype); +} + +inline Tensor& matrix_norm_out(const Tensor& self, std::string ord, IntArrayRef dim, bool keepdim, optional<ScalarType> dtype, Tensor& result) { + return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result); +} + /// See https://pytorch.org/docs/master/linalg.html#torch.linalg.matrix_power inline Tensor matrix_power(const Tensor& self, int64_t n) { return detail::matrix_power(self, n); diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 757c9855688..d7498dcbe4e 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -947,139 +947,6 @@ Examples:: [1, 2, 2, 2]]) """) -vector_norm = _add_docstr(_linalg.linalg_vector_norm, r""" -linalg.vector_norm(A, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor - -Computes a vector norm. - -If :attr:`A` is complex valued, it computes the norm of :attr:`A`\ `.abs()` - -Supports inputs of float, double, cfloat and cdouble dtypes. 
-Also supports batched inputs, and, if the input is batched, the output is batched with the same dimensions. - -- If :attr:`dim`\ `= None`, :attr:`A` will be flattened before the norm is computed. -- If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions - and the other dimensions will be treated as batch dimensions. - -:attr:`ord` defines the vector norm that is computed. The following norms are supported: - -====================== ======================================================== -:attr:`ord` vector norm -====================== ======================================================== -`2` (default) `2`-norm -`inf` `max(abs(x))` -`-inf` `min(abs(x))` -`0` `sum(x != 0)` -other `int` or `float` `sum(abs(x)**\ `:attr:`ord`\ `)**(1./\ `:attr:`ord`\ `)` -====================== ======================================================== - -where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. - -Args: - A (Tensor): tensor of shape `(*, n)` where `*` is zero or more batch dimensions. - ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2` - dim (int, Tuple[int], optional): dimensions over which to compute - the norm. See above for the behavior when :attr:`dim`\ `= None`. - Default: `None` - keepdim (bool, optional): If set to `True`, the reduced dimensions are retained - in the result as dimensions with shape one. Default: `False` - -Keyword args: - out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. - dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to - :attr:`dtype` before performing the operation, and the returned tensor's type - will be :attr:`dtype`. Default: `None` - -Returns: - A real-valued tensor, even when :attr:`A` is complex. 
- -Examples:: - - >>> from torch import linalg as LA - >>> a = torch.arange(9, dtype=torch.float) - 4 - >>> a - tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) - >>> b = a.reshape((3, 3)) - >>> b - tensor([[-4., -3., -2.], - [-1., 0., 1.], - [ 2., 3., 4.]]) - >>> LA.vector_norm(a, ord=3.5) - tensor(5.4345) - >>> LA.vector_norm(b, ord=3.5) - tensor(5.4345) -""") - -multi_dot = _add_docstr(_linalg.linalg_multi_dot, r""" -linalg.multi_dot(tensors, *, out=None) - -Efficiently multiplies two or more matrices by reordering the multiplications so that -the fewest arithmetic operations are performed. - -Supports inputs of float, double, cfloat and cdouble dtypes. -This function does not support batched inputs. -Every tensor in :attr:`tensors` must be 2D, except for the first and last which -may be 1D. If the first tensor is a 1D vector of shape `(n,)` it is treated as a row vector -of shape `(1, n)`, similarly if the last tensor is a 1D vector of shape `(n,)` it is treated -as a column vector of shape `(n, 1)`. - -If the first and last tensors are matrices, the output will be a matrix. -However, if either is a 1D vector, then the output will be a 1D vector. - -Differences with `numpy.linalg.multi_dot`: - -- Unlike `numpy.linalg.multi_dot`, the first and last tensors must either be 1D or 2D - whereas NumPy allows them to be nD - -.. warning:: This function does not broadcast. - -.. note:: This function is implemented by chaining :func:`torch.mm` calls after - computing the optimal matrix multiplication order. - -.. note:: The cost of multiplying two matrices with shapes `(a, b)` and `(b, c)` is - `a * b * c`. Given matrices `A`, `B`, `C` with shapes `(10, 100)`, - `(100, 5)`, `(5, 50)` respectively, we can calculate the cost of different - multiplication orders as follows: - - .. 
math:: - - \begin{align*} - \operatorname{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 7500 \\ - \operatorname{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 75000 - \end{align*} - - In this case, multiplying `A` and `B` first followed by `C` is 10 times faster. - -Args: - tensors (Sequence[Tensor]): two or more tensors to multiply. The first and last - tensors may be 1D or 2D. Every other tensor must be 2D. - -Keyword args: - out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. - -Examples:: - - >>> from torch.linalg import multi_dot - - >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])]) - tensor(8) - >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])]) - tensor([8]) - >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])]) - tensor([[8]]) - - >>> a = torch.arange(2 * 3).view(2, 3) - >>> b = torch.arange(3 * 2).view(3, 2) - >>> c = torch.arange(2 * 2).view(2, 2) - >>> multi_dot((a, b, c)) - tensor([[ 26, 49], - [ 80, 148]]) - - >>> multi_dot((a.to(torch.float), torch.empty(3, 0), torch.empty(0, 2))) - tensor([[0., 0.], - [0., 0.]]) -""") - norm = _add_docstr(_linalg.linalg_norm, r""" linalg.norm(A, ord=None, dim=None, keepdim=False, *, out=None, dtype=None) -> Tensor @@ -1203,6 +1070,212 @@ Using the :attr:`dim` argument to compute matrix norms:: (tensor(3.7417), tensor(11.2250)) """) +vector_norm = _add_docstr(_linalg.linalg_vector_norm, r""" +linalg.vector_norm(A, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor + +Computes a vector norm. + +If :attr:`A` is complex valued, it computes the norm of :attr:`A`\ `.abs()` + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batched inputs, and, if the input is batched, the output is batched with the same dimensions. + +- If :attr:`dim`\ `= None`, :attr:`A` will be flattened before the norm is computed. 
+- If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions + and the other dimensions will be treated as batch dimensions. + +:attr:`ord` defines the vector norm that is computed. The following norms are supported: + +====================== ======================================================== +:attr:`ord` vector norm +====================== ======================================================== +`2` (default) `2`-norm +`inf` `max(abs(x))` +`-inf` `min(abs(x))` +`0` `sum(x != 0)` +other `int` or `float` `sum(abs(x)**\ `:attr:`ord`\ `)**(1./\ `:attr:`ord`\ `)` +====================== ======================================================== + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +Args: + A (Tensor): tensor of shape `(*, n)` where `*` is zero or more batch dimensions. + ord (int, float, inf, -inf, optional): order of norm. Default: `2` + dim (int, Tuple[int], optional): dimensions over which to compute + the norm. See above for the behavior when :attr:`dim`\ `= None`. + Default: `None` + keepdim (bool, optional): If set to `True`, the reduced dimensions are retained + in the result as dimensions with size one. Default: `False` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to + :attr:`dtype` before performing the operation, and the returned tensor's type + will be :attr:`dtype`. Default: `None` + +Returns: + A real-valued tensor, even when :attr:`A` is complex. 
+ +Examples:: + + >>> from torch import linalg as LA + >>> a = torch.arange(9, dtype=torch.float) - 4 + >>> a + tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) + >>> b = a.reshape((3, 3)) + >>> b + tensor([[-4., -3., -2.], + [-1., 0., 1.], + [ 2., 3., 4.]]) + >>> LA.vector_norm(a, ord=3.5) + tensor(5.4345) + >>> LA.vector_norm(b, ord=3.5) + tensor(5.4345) +""") + +matrix_norm = _add_docstr(_linalg.linalg_matrix_norm, r""" +linalg.matrix_norm(A, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None, out=None) -> Tensor + +Computes a matrix norm. + +If :attr:`A` is complex valued, it computes the norm of :attr:`A`\ `.abs()` + +Supports inputs of float, double, cfloat and cdouble dtypes. +Also supports batched inputs, and, if the input is batched, the output is batched with the same dimensions. + +The norm will be computed over the dimensions specified by the 2-tuple :attr:`dim` +and the other dimensions will be treated as batch dimensions. + +:attr:`ord` defines the matrix norm that is computed. The following norms are supported: + +====================== ======================================================== +:attr:`ord` matrix norm +====================== ======================================================== +`'fro'` (default) Frobenius norm +`'nuc'` nuclear norm +`inf` `max(sum(abs(x), dim=1))` +`-inf` `min(sum(abs(x), dim=1))` +`1` `max(sum(abs(x), dim=0))` +`-1` `min(sum(abs(x), dim=0))` +`2` largest singular value +`-2` smallest singular value +====================== ======================================================== + +where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. + +Args: + A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. + ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `'fro'` + dim (Tuple[int, int], optional): dimensions over which to compute the norm. 
Default: `(-2, -1)` + keepdim (bool, optional): If set to `True`, the reduced dimensions are retained + in the result as dimensions with size one. Default: `False` + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. + dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to + :attr:`dtype` before performing the operation, and the returned tensor's type + will be :attr:`dtype`. Default: `None` + +Returns: + A real-valued tensor, even when :attr:`A` is complex. + +Examples:: + + >>> from torch import linalg as LA + >>> A = torch.arange(9, dtype=torch.float).reshape(3, 3) + >>> A + tensor([[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]) + >>> LA.matrix_norm(A) + tensor(14.2829) + >>> LA.matrix_norm(A, ord=-1) + tensor(9.) + >>> B = A.expand(2, -1, -1) + >>> B + tensor([[[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]], + + [[0., 1., 2.], + [3., 4., 5.], + [6., 7., 8.]]]) + >>> LA.matrix_norm(B) + tensor([14.2829, 14.2829]) + >>> LA.matrix_norm(B, dim=(0, 2)) + tensor([ 3.1623, 10.0000, 17.2627]) +""") + +multi_dot = _add_docstr(_linalg.linalg_multi_dot, r""" +linalg.multi_dot(tensors, *, out=None) + +Efficiently multiplies two or more matrices by reordering the multiplications so that +the fewest arithmetic operations are performed. + +Supports inputs of float, double, cfloat and cdouble dtypes. +This function does not support batched inputs. +Every tensor in :attr:`tensors` must be 2D, except for the first and last which +may be 1D. If the first tensor is a 1D vector of shape `(n,)` it is treated as a row vector +of shape `(1, n)`, similarly if the last tensor is a 1D vector of shape `(n,)` it is treated +as a column vector of shape `(n, 1)`. + +If the first and last tensors are matrices, the output will be a matrix. +However, if either is a 1D vector, then the output will be a 1D vector. 
+ +Differences with `numpy.linalg.multi_dot`: + +- Unlike `numpy.linalg.multi_dot`, the first and last tensors must either be 1D or 2D + whereas NumPy allows them to be nD + +.. warning:: This function does not broadcast. + +.. note:: This function is implemented by chaining :func:`torch.mm` calls after + computing the optimal matrix multiplication order. + +.. note:: The cost of multiplying two matrices with shapes `(a, b)` and `(b, c)` is + `a * b * c`. Given matrices `A`, `B`, `C` with shapes `(10, 100)`, + `(100, 5)`, `(5, 50)` respectively, we can calculate the cost of different + multiplication orders as follows: + + .. math:: + + \begin{align*} + \operatorname{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 7500 \\ + \operatorname{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 75000 + \end{align*} + + In this case, multiplying `A` and `B` first followed by `C` is 10 times faster. + +Args: + tensors (Sequence[Tensor]): two or more tensors to multiply. The first and last + tensors may be 1D or 2D. Every other tensor must be 2D. + +Keyword args: + out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. 
+ +Examples:: + + >>> from torch.linalg import multi_dot + + >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])]) + tensor(8) + >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])]) + tensor([8]) + >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])]) + tensor([[8]]) + + >>> a = torch.arange(2 * 3).view(2, 3) + >>> b = torch.arange(3 * 2).view(3, 2) + >>> c = torch.arange(2 * 2).view(2, 2) + >>> multi_dot((a, b, c)) + tensor([[ 26, 49], + [ 80, 148]]) + + >>> multi_dot((a.to(torch.float), torch.empty(3, 0), torch.empty(0, 2))) + tensor([[0., 0.], + [0., 0.]]) +""") + svd = _add_docstr(_linalg.linalg_svd, r""" linalg.svd(A, full_matrices=True, *, out=None) -> (Tensor, Tensor, Tensor) diff --git a/torch/overrides.py b/torch/overrides.py index 2ed5d3ec622..4a0ee2b8e20 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -730,6 +730,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1, torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1, torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1, + torch.linalg.matrix_norm: lambda input, ord='fro', dim=(-2, -1), keepdim=False, out=None, dtype=None: -1, torch.norm_except_dim: lambda v, pow=2, dim=0: -1, torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1, torch.numel: lambda input: -1, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 083defb70d6..20d76dc1e6d 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -545,6 +545,18 @@ def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad): return result +def sample_inputs_linalg_matrix_norm(op_info, device, dtype, requires_grad, **kwargs): + sizes = ((2, 2), 
(2, 3, 2)) + ords = ('fro', 'nuc', inf, -inf, 1, -1, 2, -2) + dims = ((-2, -1), (-1, 0)) + + inputs: List[SampleInput] = [] + for size, ord, dim, keepdim in product(sizes, ords, dims, [True, False]): + t = make_tensor(size, device, dtype, requires_grad=requires_grad) + inputs.append(SampleInput(t, args=(ord, dim, keepdim))) + + return inputs + def sample_inputs_linalg_norm(op_info, device, dtype, requires_grad): test_sizes = [ (S,), @@ -564,8 +576,6 @@ def sample_inputs_linalg_norm(op_info, device, dtype, requires_grad): inputs = [] - is_dtype_half = dtype in [torch.float16, torch.bfloat16] - for test_size in test_sizes: is_vector_norm = len(test_size) == 1 is_matrix_norm = len(test_size) == 2 @@ -4545,6 +4555,15 @@ op_db: List[OpInfo] = [ # linalg.norm does not correctly warn when resizing out= inputs SkipInfo('TestCommon', 'test_out'), )), + OpInfo('linalg.matrix_norm', + aten_name='linalg_matrix_norm', + dtypes=floating_and_complex_types(), + decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], + sample_inputs_func=sample_inputs_linalg_matrix_norm, + skips=( + # linalg.matrix_norm does not correctly warn when resizing out= inputs + SkipInfo('TestCommon', 'test_out'), + )), OpInfo('linalg.qr', aten_name='linalg_qr', op=torch.linalg.qr,