From f87f753bb997b2da82f7d2a561ccb40ab4f6bd9d Mon Sep 17 00:00:00 2001
From: Brian Hirsh
Date: Mon, 14 Feb 2022 11:39:05 -0800
Subject: [PATCH] avoiding adding some functions to the public python API
 before 1.11 release (#72543)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72543

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D34085724

Pulled By: bdhirsh

fbshipit-source-id: 941d5a90a6fa5328268d623e0e2b01577e4132ca
(cherry picked from commit 6676a0c79a3b2bc1aa95e09e91eb92a6eca6b764)
---
 aten/src/ATen/native/native_functions.yaml   |  2 +-
 .../check_forward_backward_compatibility.py  |  1 +
 test/test_torch.py                           |  8 +++---
 tools/autograd/derivatives.yaml              |  2 +-
 torch/nn/init.py                             | 25 +++++++++++--------
 torch/overrides.py                           |  2 +-
 .../_internal/common_methods_invocations.py  |  2 +-
 7 files changed, 23 insertions(+), 19 deletions(-)

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index a84eddacbb7..7c252141099 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6061,7 +6061,7 @@
 - func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
   variants: function, method
 
-- func: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
+- func: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
   variants: function, method
   dispatch:
     CPU: scatter_reduce_two_cpu
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index e15ac0f29bc..2297dec5c2f 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -106,6 +106,7 @@ ALLOW_LIST = [
     ("aten::_scatter_reduce", datetime.date(2022, 1, 31)),
     ("aten::native_multi_head_self_attention", datetime.date(9999, 1, 1)),
     ("aten::_native_multi_head_self_attention", datetime.date(9999, 1, 1)),
+    ("aten::scatter_reduce.two", datetime.date(2022, 3, 15)),
 ]
 
 ALLOW_LIST_COMPILED = [
diff --git a/test/test_torch.py b/test/test_torch.py
index d06dfb97a54..e2422d1477d 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5773,7 +5773,7 @@ class TestTorch(TestCase):
 
         for reduce in reduces:
             for dim in range(len(shape)):
-                output = input.scatter_reduce(dim, index, reduce, output_size=output_size)
+                output = input._scatter_reduce(dim, index, reduce, output_size=output_size)
 
                 # Check that output is of the correct size
                 output_shape = copy.copy(shape)
@@ -5807,16 +5807,16 @@ class TestTorch(TestCase):
         self.assertTrue(torch.allclose(output, expected))
 
         with self.assertRaisesRegex(RuntimeError, "Expected `dim` to be in range -3 to 2"):
-            torch.scatter_reduce(input, 4, index, "sum")
+            torch._scatter_reduce(input, 4, index, "sum")
 
         with self.assertRaisesRegex(RuntimeError, "Shape mismatch"):
             index2 = torch.randint(0, output_size, (10, ), dtype=torch.long, device=device)
-            torch.scatter_reduce(input, 0, index2, "sum")
+            torch._scatter_reduce(input, 0, index2, "sum")
 
         with self.assertRaisesRegex(RuntimeError, "Expected `index` values to be in range 0 to 2"):
             input2 = torch.randn(10, dtype=dtype, device=device)
             index2 = torch.tensor([0, 1, 0, 1, 2, 3, 3, 4, 4, 3])
-            torch.scatter_reduce(input2, 0, index2, "sum", output_size=2)
+            torch._scatter_reduce(input2, 0, index2, "sum", output_size=2)
 
     def test_structseq_repr(self):
         a = torch.arange(250).reshape(5, 5, 10)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 7f7c13f01aa..27e4007d569 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2595,6 +2595,6 @@
 - name: _efficientzerotensor(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   output_differentiability: [False]
 
-- name: scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
+- name: _scatter_reduce.two(Tensor self, int dim, Tensor index, str reduce, *, int? output_size=None) -> Tensor
   self: scatter_reduce_backward(grad, self, dim, index, reduce, result)
   index: non_differentiable
diff --git a/torch/nn/init.py b/torch/nn/init.py
index 357fb7498c5..ce83137845f 100644
--- a/torch/nn/init.py
+++ b/torch/nn/init.py
@@ -4,9 +4,6 @@ import warnings
 
 from torch import Tensor
 import torch
-from ..overrides import (
-    has_torch_function_variadic,
-    handle_torch_function)
 
 # These no_grad_* functions are necessary as wrappers around the parts of these
 # functions that use `with torch.no_grad()`. The JIT doesn't support context
@@ -135,8 +132,8 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
         >>> w = torch.empty(3, 5)
         >>> nn.init.uniform_(w)
     """
-    if has_torch_function_variadic(tensor):
-        return handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
+    if torch.overrides.has_torch_function_variadic(tensor):
+        return torch.overrides.handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
     return _no_grad_uniform_(tensor, a, b)
 
 
@@ -153,8 +150,8 @@ def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
         >>> w = torch.empty(3, 5)
         >>> nn.init.normal_(w)
     """
-    if has_torch_function_variadic(tensor):
-        return handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
+    if torch.overrides.has_torch_function_variadic(tensor):
+        return torch.overrides.handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
     return _no_grad_normal_(tensor, mean, std)
 
 def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
@@ -190,8 +187,8 @@ def constant_(tensor: Tensor, val: float) -> Tensor:
         >>> w = torch.empty(3, 5)
         >>> nn.init.constant_(w, 0.3)
     """
-    if has_torch_function_variadic(tensor):
-        return handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
+    if torch.overrides.has_torch_function_variadic(tensor):
+        return torch.overrides.handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
     return _no_grad_fill_(tensor, val)
 
 
@@ -393,8 +390,14 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
         >>> w = torch.empty(3, 5)
         >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
     """
-    if has_torch_function_variadic(tensor):
-        return handle_torch_function(kaiming_uniform_, (tensor,), tensor=tensor, a=a, mode=mode, nonlinearity=nonlinearity)
+    if torch.overrides.has_torch_function_variadic(tensor):
+        return torch.overrides.handle_torch_function(
+            kaiming_uniform_,
+            (tensor,),
+            tensor=tensor,
+            a=a,
+            mode=mode,
+            nonlinearity=nonlinearity)
 
     if 0 in tensor.shape:
         warnings.warn("Initializing zero-element tensors is a no-op")
diff --git a/torch/overrides.py b/torch/overrides.py
index 76a5fe67069..c8ef49e7b9d 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -897,7 +897,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
         torch.scatter: lambda input, dim, index, src: -1,
         torch.scatter_add: lambda input, dim, index, src: -1,
-        torch.scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
+        torch._scatter_reduce: lambda input, dim, index, reduce, output_size=None: -1,
         torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
         torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
         torch.select: lambda input, dim, index: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index c095f6a8523..45c06edb9a3 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -15517,7 +15517,7 @@ op_db: List[OpInfo] = [
         supports_fwgrad_bwgrad=True,
     ),
     OpInfo(
-        'scatter_reduce',
+        '_scatter_reduce',
        dtypes=all_types_and(torch.float16, torch.bfloat16),
        sample_inputs_func=sample_inputs_scatter_reduce,
        supports_out=False,
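
Usage note, not part of the patch above: a minimal sketch of what a call site
looks like once this change lands, assuming a build that includes the patch
(the op is then exposed only under the private name torch._scatter_reduce; the
public torch.scatter_reduce name is held back until after the 1.11 release, as
the commit subject and the dated BC allowlist entry indicate):

    import torch

    src = torch.tensor([1., 2., 3., 4., 5.])
    index = torch.tensor([0, 1, 0, 1, 2])

    # Reduce entries of `src` that share an index slot along dim 0:
    # out[i] aggregates src[j] for every j with index[j] == i, using the
    # "sum" reduction; `output_size` fixes the length of the result.
    out = torch._scatter_reduce(src, 0, index, "sum", output_size=3)
    print(out)  # tensor([4., 6., 5.])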