From d1bb8e828f280d1c66fff193c043d5bc36154577 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Thu, 26 Sep 2024 04:52:03 +0000 Subject: [PATCH] Add deterministic path for CUDA `cumsum` (#136224) Change `cumsum` to call its decomposition when `use_deterministic_algorithms(True)` and input is CUDA. Fixes #89492 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136224 Approved by: https://github.com/ezyang, https://github.com/justinchuby --- test/test_torch.py | 24 +++++++++++++++++- tools/pyi/gen_pyi.py | 1 + torch/__init__.py | 2 +- torch/_tensor.py | 31 +++++++++++++++++++++++ torch/_tensor_docs.py | 9 ------- torch/_torch_docs.py | 32 ------------------------ torch/functional.py | 58 +++++++++++++++++++++++++++++++++++++++++++ torch/overrides.py | 2 +- 8 files changed, 115 insertions(+), 44 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index ec7a02c44f6..3d5f0963c45 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1739,11 +1739,33 @@ else: 'embedding_bag_backward_cuda_max', torch.device(device).type == 'cuda') + @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") + @onlyCUDA + def test_deterministic_cumsum(self, device): + test_cases = [ + # size, dim + [(2, 3, 4), 0], + [(2, 3, 4), 1], + [(2, 3, 4), 2], + [(1000, 10, 2), 0], + ] + for size, dim in test_cases: + input = 100 * torch.randn(*size, device=device) + with DeterministicGuard(True): + res0 = input.cumsum(dim) + for _ in range(3): + res1 = input.cumsum(dim) + self.assertEqual(res0, res1, atol=0, rtol=0) + + res_cpu = input.cpu().cumsum(dim) + self.assertEqual(res0, res_cpu, atol=1e-3, rtol=1e-2) + + @dtypes(*all_types_and_complex_and(torch.bool)) @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") def test_nondeterministic_alert_cumsum(self, device, dtype): input = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9) - should_alert = torch.device(device).type == 'cuda' and (dtype.is_floating_point or dtype.is_complex) + should_alert = False for op_call in [torch.Tensor.cumsum, torch.cumsum]: self.check_nondeterministic_alert( diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index bd2fcee5e51..e425aac83b4 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -136,6 +136,7 @@ blocklist = [ "requires_grad", "range", # defined in functional + "cumsum", "einsum", # Somehow, these are defined in both _C and in functional. Ick! "broadcast_tensors", diff --git a/torch/__init__.py b/torch/__init__.py index 64bfab70838..9f13b2eedf3 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1235,6 +1235,7 @@ def use_deterministic_algorithms( * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU tensor * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor + * :func:`torch.cumsum` when called on a CUDA tensor * :func:`torch.gather` when called on a CUDA tensor that requires grad * :func:`torch.index_add` when called on CUDA tensor * :func:`torch.index_select` when attempting to differentiate a CUDA tensor @@ -1281,7 +1282,6 @@ def use_deterministic_algorithms( * :func:`torch.kthvalue` with called on a CUDA tensor * :func:`torch.median` with indices output when called on a CUDA tensor * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor - * :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor * :func:`torch.Tensor.resize_` when called with a quantized tensor diff --git a/torch/_tensor.py b/torch/_tensor.py index e22c6a92d92..b8c00c60565 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -846,6 +846,37 @@ class Tensor(torch._C.TensorBase): return _symeig(self, eigenvectors=eigenvectors) + def cumsum( + self, + dim=None, + *, + dtype=None, + out=None, + axis=None, + ): + r""" + cumsum(dim, dtype=None) -> Tensor + + See :func:`torch.cumsum` + """ + if axis is not None and dim is not None: + raise RuntimeError("expected either 'dim' or 'axis' to be given, not both") + elif axis is not None: + dim = axis + if has_torch_function_unary(self): + return handle_torch_function( + Tensor.cumsum, + (self,), + self, + dim, + dtype=dtype, + out=out, + ) + if out is None: + return torch.cumsum(self, dim, dtype=dtype) + else: + return torch.cumsum(self, dim, dtype=dtype, out=out) + def lu(self, pivot=True, get_infos=False): r"""See :func:`torch.lu`""" # If get_infos is True, then we don't need to check for errors and vice versa diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 1ee0548eb15..28f72d208f0 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1497,15 +1497,6 @@ In-place version of :meth:`~Tensor.cumprod` """, ) -add_docstr_all( - "cumsum", - r""" -cumsum(dim, dtype=None) -> Tensor - -See :func:`torch.cumsum` -""", -) - add_docstr_all( "cumsum_", r""" diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 5eeced5dec1..b27c077b3fb 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3317,38 +3317,6 @@ Example:: """.format(**reduceops_common_args), ) -add_docstr( - torch.cumsum, - r""" -cumsum(input, dim, *, dtype=None, out=None) -> Tensor - -Returns the cumulative sum of elements of :attr:`input` in the dimension -:attr:`dim`. - -For example, if :attr:`input` is a vector of size N, the result will also be -a vector of size N, with elements. - -.. math:: - y_i = x_1 + x_2 + x_3 + \dots + x_i - -Args: - {input} - dim (int): the dimension to do the operation over - -Keyword args: - {dtype} - {out} - -Example:: - - >>> a = torch.randint(1, 20, (10,)) - >>> a - tensor([13, 7, 3, 10, 13, 3, 15, 10, 9, 10]) - >>> torch.cumsum(a, dim=0) - tensor([13, 20, 23, 33, 46, 49, 64, 74, 83, 93]) -""".format(**reduceops_common_args), -) - add_docstr( torch.count_nonzero, r""" diff --git a/torch/functional.py b/torch/functional.py index 9180262708e..4d922480947 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import importlib import itertools import operator from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union @@ -28,6 +29,7 @@ __all__ = [ "block_diag", "cdist", "chain_matmul", + "cumsum", "einsum", "istft", "lu", @@ -2035,6 +2037,62 @@ def chain_matmul(*matrices, out=None): return _VF.chain_matmul(matrices, out=out) # type: ignore[attr-defined] +def cumsum( + self: Tensor, + dim: Optional[int] = None, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, + axis: Optional[int] = None, +): + r""" + cumsum(input, dim, *, dtype=None, out=None) -> Tensor + + Returns the cumulative sum of elements of :attr:`input` in the dimension + :attr:`dim`. + + For example, if :attr:`input` is a vector of size N, the result will also be + a vector of size N, with elements. + + .. math:: + y_i = x_1 + x_2 + x_3 + \dots + x_i + + Args: + input (Tensor): the input tensor. + dim (int): the dimension to do the operation over + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + If specified, the input tensor is casted to :attr:`dtype` before the operation + is performed. This is useful for preventing data type overflows. Default: None. + out (Tensor, optional): the output tensor. + + Example:: + + >>> torch.manual_seed(0) + >>> a = torch.randint(1, 20, (10,)) + >>> a + tensor([16, 5, 1, 1, 12, 8, 6, 10, 10, 5]) + >>> torch.cumsum(a, dim=0) + tensor([16, 21, 22, 23, 35, 43, 49, 59, 69, 74]) + """ + if axis is not None: + if dim is None: + dim = axis + else: + raise RuntimeError("expected either 'dim' or 'axis' to be given, not both") + if has_torch_function_unary(self): + return handle_torch_function(cumsum, (self,), self, dim, dtype=dtype, out=out) + if not torch.jit.is_scripting(): + if torch.are_deterministic_algorithms_enabled() and self.is_cuda: + ref_func = importlib.import_module("torch._refs").cumsum + return ref_func(self, dim, dtype=dtype, out=out) + if out is None: + return _VF.cumsum(self, dim, dtype=dtype) # type: ignore[attr-defined] + else: + return _VF.cumsum(self, dim, dtype=dtype, out=out) # type: ignore[attr-defined] + + def _lu_impl(A, pivot=True, get_infos=False, out=None): # type: (Tensor, bool, bool, Any) -> Tuple[Tensor, Tensor, Tensor] r"""Computes the LU factorization of a matrix or batches of matrices diff --git a/torch/overrides.py b/torch/overrides.py index 7a568d7e22c..a638ccece63 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -553,7 +553,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.cummax: lambda input, dim, out=None: -1, torch.cummin: lambda input, dim, out=None: -1, torch.cumprod: lambda input, dim, out=None, dtype=None: -1, - torch.cumsum: lambda input, dim, out=None, dtype=None: -1, + torch.cumsum: lambda input, dim, out=None, dtype=None, axis=None: -1, torch.cumulative_trapezoid: lambda y, x=None, dim=-1: -1, torch.logcumsumexp: lambda input, dim, out=None: -1, torch.deg2rad: lambda input, out=None: -1,