From e8dc34eaebd7fbb6433c6e019b50d099ea505a8b Mon Sep 17 00:00:00 2001
From: "Li-Huai (Allan) Lin"
Date: Thu, 16 Feb 2023 01:13:08 +0000
Subject: [PATCH] [MPS] Move max_pool2d to mps dispatch key (#90772)

Related issue: #77394

This PR also modifies some assertions in the codegen; an explanatory
comment has been added for them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90772
Approved by: https://github.com/albanD
---
 aten/src/ATen/native/Pooling.cpp                 |  7 -------
 aten/src/ATen/native/mps/operations/Pooling.mm   |  2 +-
 aten/src/ATen/native/native_functions.yaml       | 13 ++++---------
 .../HasDecompTest.test_has_decomposition.expect  |  6 ++----
 .../check_forward_backward_compatibility.py      |  4 ++++
 tools/autograd/derivatives.yaml                  |  4 ++--
 torchgen/model.py                                | 11 ++++++++++-
 7 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/aten/src/ATen/native/Pooling.cpp b/aten/src/ATen/native/Pooling.cpp
index fcbe741ab0e..24e813a485a 100644
--- a/aten/src/ATen/native/Pooling.cpp
+++ b/aten/src/ATen/native/Pooling.cpp
@@ -9,7 +9,6 @@
 #include
 #include
 #else
-#include <ATen/ops/_mps_max_pool2d.h>
 #include
 #include
 #include
@@ -141,12 +140,6 @@ Tensor max_pool2d(
     return at::mkldnn_max_pool2d(
         self, kernel_size, stride, padding, dilation, ceil_mode);
   }
-#ifdef USE_MPS
-  if (self.is_mps()) {
-    return at::_mps_max_pool2d(
-        self, kernel_size, stride, padding, dilation, ceil_mode);
-  }
-#endif
 #if defined(C10_MOBILE)
   if(xnnpack::use_max_pool2d(self, kernel_size, padding, stride,
                              dilation, ceil_mode)) {
diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm
index 08727fed826..ff26ff83518 100644
--- a/aten/src/ATen/native/mps/operations/Pooling.mm
+++ b/aten/src/ATen/native/mps/operations/Pooling.mm
@@ -308,7 +308,7 @@ static void avg_pool2d_template(const Tensor& input, const Tensor& output,
 
 } // namespace mps
 
-Tensor _mps_max_pool2d(
+Tensor mps_max_pool2d(
     const Tensor& input,
     IntArrayRef kernel_size,
     IntArrayRef stride,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2cae01f109d..23f40e27c44 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3567,19 +3567,14 @@
 - func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
 
 - func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
-
-# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
-# native_functions.yaml
-# https://github.com/pytorch/pytorch/issues/77394
-- func: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
   dispatch:
-    MPS: _mps_max_pool2d
-  autogen: _mps_max_pool2d.out
+    CompositeImplicitAutograd: max_pool2d
+    MPS: mps_max_pool2d
 
-- func: mps_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+- func: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
   dispatch:
     MPS: mps_max_pool2d_backward
-  autogen: mps_max_pool2d_backward.out
+  autogen: max_pool2d_backward.out
 
 - func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
   dispatch:
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index ed52d371ca5..49db57b3e04 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -377,8 +377,6 @@ aten::_mps_convolution
 aten::_mps_convolution.out
 aten::_mps_convolution_transpose
 aten::_mps_convolution_transpose.out
-aten::_mps_max_pool2d
-aten::_mps_max_pool2d.out
 aten::_native_batch_norm_legit.no_stats_out
 aten::_native_batch_norm_legit.out
 aten::_native_decoder_only_multi_head_attention
@@ -857,6 +855,8 @@ aten::max
 aten::max.dim
 aten::max.dim_max
 aten::max.unary_out
+aten::max_pool2d_backward
+aten::max_pool2d_backward.out
 aten::max_pool2d_with_indices
 aten::max_pool2d_with_indices.out
 aten::max_pool2d_with_indices_backward
@@ -930,8 +930,6 @@ aten::mps_convolution_backward
 aten::mps_convolution_backward.out
 aten::mps_convolution_transpose_backward
 aten::mps_convolution_transpose_backward.out
-aten::mps_max_pool2d_backward
-aten::mps_max_pool2d_backward.out
 aten::multi_margin_loss
 aten::multi_margin_loss.out
 aten::multi_margin_loss_backward
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index bca79d85425..ef51743c929 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -150,6 +150,10 @@ ALLOW_LIST = [
     ("aten::sum.SymInt", datetime.date(2022, 11, 30)),
     ("aten::mps_linear", datetime.date(9999, 1, 1)),
     ("aten::_mps_linear", datetime.date(9999, 1, 1)),
+    ("aten::_mps_max_pool2d", datetime.date(9999, 1, 1)),
+    ("aten::_mps_max_pool2d.out", datetime.date(9999, 1, 1)),
+    ("aten::mps_max_pool2d_backward", datetime.date(9999, 1, 1)),
+    ("aten::mps_max_pool2d_backward.out", datetime.date(9999, 1, 1)),
     ("aten::view_copy.SymInt", datetime.date(2022, 11, 30)),
     ("aten::view_copy.SymInt_out", datetime.date(2022, 11, 30)),
     ("aten::expand_copy.SymInt", datetime.date(2022, 11, 30)),
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 1c2bfd4b2b8..d377abe59a4 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2170,8 +2170,8 @@
   input, weight, bias: linear_backward(input, grad, weight, grad_input_mask)
 
 #mps
-- name: _mps_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
-  self: mps_max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
+- name: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+  self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
 
 - name: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor
   self, weight, bias: "grad.defined() ? mps_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
diff --git a/torchgen/model.py b/torchgen/model.py
index e6897ded472..75f2b089232 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -638,6 +638,7 @@ class NativeFunction:
         raw_dispatch = e.pop("dispatch", None)
         assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
         dispatch: Dict[DispatchKey, BackendMetadata] = {}
+        num_dispatch_keys: int = 0
         if raw_dispatch is not None:
             assert not manual_kernel_registration, (
                 "cannot specify both manual_kernel_registration and dispatch; with "
@@ -650,6 +651,8 @@ class NativeFunction:
                 assert isinstance(ks, str), e
                 for k in ks.split(","):
                     dispatch_key = DispatchKey.parse(k.strip())
+                    num_dispatch_keys += 1
+
                     if ignore_keys and dispatch_key in ignore_keys:
                         continue
                     assert dispatch_key in dispatch_keys, (
@@ -677,7 +680,12 @@ class NativeFunction:
             ):
                 redundant_composite_implicit_autograd = True
 
-        assert not (len(dispatch) == 1 and redundant_composite_implicit_autograd), (
+        # We count the number of dispatch keys which have not been ignored to prevent a dispatch table
+        # in which all backend keys are ignored but necessarily kept, remaining compositeimplicit,
+        # from being treated as redundant.
+        assert not (
+            num_dispatch_keys == 1 and redundant_composite_implicit_autograd
+        ), (
             "unnecessary dispatch table for this function; just delete the dispatch "
             "key entirely"
         )
@@ -687,6 +695,7 @@
             structured_delegate
             or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
             or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
+            or num_dispatch_keys != 1
         ), (
            f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
            f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected "
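
For reviewers who want to sanity-check the behavior this patch targets, here is a minimal, hypothetical smoke test (not part of the patch). It assumes a macOS build of PyTorch with the MPS backend available, and exercises the paths that now dispatch plain max_pool2d through the MPS key in the forward pass and through the renamed aten::max_pool2d_backward in the backward pass, per the native_functions.yaml and derivatives.yaml hunks above.

# Hypothetical smoke test; assumes an MPS-enabled macOS build of PyTorch.
# Forward should hit the MPS kernel registered for max_pool2d (mps_max_pool2d),
# and backward should go through aten::max_pool2d_backward per derivatives.yaml.
import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    x = torch.randn(1, 3, 32, 32, device="mps", requires_grad=True)
    y = F.max_pool2d(x, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    y.sum().backward()
    print(y.shape)       # expected: torch.Size([1, 3, 16, 16])
    print(x.grad.shape)  # expected: torch.Size([1, 3, 32, 32])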