diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index c6f0412c9d0..513c6da601d 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -cc19c3abcbb3f702d5f468ee08549edd926ef549 +2bdd718b4b7309b5868825e261ae05bef6be548f diff --git a/aten/src/ATen/BatchingRegistrations.cpp b/aten/src/ATen/BatchingRegistrations.cpp index e4b5f7028d2..f52af4ce71d 100644 --- a/aten/src/ATen/BatchingRegistrations.cpp +++ b/aten/src/ATen/BatchingRegistrations.cpp @@ -56,18 +56,21 @@ static bool is_allowed_dim_on_scalar_tensor(int64_t dim) { return dim == 0 || dim == -1; } -Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, optional<ScalarType> dtype) { - // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail - // and instead returns a new scalar tensor (this also happens for dim=-1) - // If the following happens: - // >>> x = torch.randn(B0) # the per-examples are all scalars - // >>> vmap(partial(torch.sum, dim=0), x) - // then we replicate the behavior of sum(scalar_tensor, dim=0). - if (/*logical*/self.dim() == 0 && (dims.size() == 0 || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) { - return self.clone(); +Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, optional<ScalarType> dtype) { + if (opt_dims.has_value()) { + auto dims = opt_dims.value(); + // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail + // and instead returns a new scalar tensor (this also happens for dim=-1) + // If the following happens: + // >>> x = torch.randn(B0) # the per-examples are all scalars + // >>> vmap(partial(torch.sum, dim=0), x) + // then we replicate the behavior of sum(scalar_tensor, dim=0). 
+ if (/*logical*/self.dim() == 0 && (dims.size() == 0 || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) { + return self.clone(); + } } auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); - auto dims_physical = self_physical.getPhysicalDims(dims); + auto dims_physical = self_physical.getPhysicalDims(opt_dims); auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype); return self_physical.getPhysicalToLogicalMap().apply(result); } diff --git a/aten/src/ATen/VmapTransforms.cpp b/aten/src/ATen/VmapTransforms.cpp index 4bda903545f..20c792f7370 100644 --- a/aten/src/ATen/VmapTransforms.cpp +++ b/aten/src/ATen/VmapTransforms.cpp @@ -55,13 +55,20 @@ int64_t VmapPhysicalView::numLogicalDims() const { return /*physical*/tensor_.dim() - numBatchDims(); } -VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const { +VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const { auto logical_ndim = numLogicalDims(); // NB: fmap doesn't have a SmallVector variant, so we don't use it here. 
VmapDimVector result; result.reserve(logical_ndim); - for (auto dim : logical_dims) { - result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims()); + if (opt_logical_dims.has_value()) { + auto logical_dims = opt_logical_dims.value(); + for (auto dim : logical_dims) { + result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims()); + } + } else { + for (int64_t dim = 0; dim < logical_ndim; dim++) { + result.push_back(dim + numBatchDims()); + } } return result; } diff --git a/aten/src/ATen/VmapTransforms.h b/aten/src/ATen/VmapTransforms.h index 89190265a95..53e476e2243 100644 --- a/aten/src/ATen/VmapTransforms.h +++ b/aten/src/ATen/VmapTransforms.h @@ -131,7 +131,7 @@ struct TORCH_API VmapPhysicalView { // This is because the size of levels tell us that the first two dimensions // of `tensor_` are batch dimensions, so a logical dim of `n` is actually // a physical dim of `n + 2`. - VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const; + VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const; int64_t getPhysicalDim(int64_t logical_dim) const; // Returns a VmapPhysicalToLogicalMap object. 
This can be used for diff --git a/aten/src/ATen/WrapDimUtilsMulti.h b/aten/src/ATen/WrapDimUtilsMulti.h index 64975e9a43c..c1899bea872 100644 --- a/aten/src/ATen/WrapDimUtilsMulti.h +++ b/aten/src/ATen/WrapDimUtilsMulti.h @@ -14,7 +14,7 @@ namespace at { constexpr size_t dim_bitset_size = 64; static inline std::bitset<dim_bitset_size> dim_list_to_bitset( - IntArrayRef dims, + OptionalIntArrayRef opt_dims, int64_t ndims) { TORCH_CHECK( ndims <= (int64_t)dim_bitset_size, @@ -22,11 +22,21 @@ static inline std::bitset<dim_bitset_size> dim_list_to_bitset( dim_bitset_size, " dims are supported"); std::bitset<dim_bitset_size> seen; - for (const auto i : c10::irange(dims.size())) { - size_t dim = maybe_wrap_dim(dims[i], ndims); - TORCH_CHECK( - !seen[dim], "dim ", dim, " appears multiple times in the list of dims"); - seen[dim] = true; + if (opt_dims.has_value()) { + auto dims = opt_dims.value(); + for (const auto i : c10::irange(dims.size())) { + size_t dim = maybe_wrap_dim(dims[i], ndims); + TORCH_CHECK( + !seen[dim], + "dim ", + dim, + " appears multiple times in the list of dims"); + seen[dim] = true; + } + } else { + for (int64_t dim = 0; dim < ndims; dim++) { + seen[dim] = true; + } } return seen; } diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 1ebcf775a54..af91c81561e 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -455,7 +455,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { // KERNEL(ADD_NS(norm), "norm.ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, IntArrayRef, bool, ScalarType), fp32_set_opt_dtype) // KERNEL(ADD_NS(norm), "norm.names_ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, DimnameList, bool, ScalarType), fp32_set_opt_dtype) KERNEL(ADD_NS(sum), "sum", Tensor (const Tensor &, c10::optional<ScalarType>), fp32_set_opt_dtype) - KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<ScalarType>), fp32_set_opt_dtype) + KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, 
OptionalIntArrayRef, bool, c10::optional), fp32_set_opt_dtype) KERNEL(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional), fp32_set_opt_dtype) // fp32_append_dtype // The fp32_append_dtype wrapper overrides implicit promotion behavior. diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 16b881e54d0..cced7d69660 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -52,8 +52,6 @@ namespace meta { static ScalarType infer_dtype_from_optional( const Tensor& self, - IntArrayRef dim, - bool keepdim, const optional& opt_dtype, const Tensor& result) { // 'opt_dtype' has the priority for both cases. @@ -187,9 +185,9 @@ TORCH_META_FUNC(cumprod) } TORCH_META_FUNC2(sum, dim_IntList) -(const Tensor& self, IntArrayRef dim, bool keepdim, optional opt_dtype) { - auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output()); - resize_reduction(*this, self, dim, keepdim, out_dtype); +(const Tensor& self, OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype) { + auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output()); + resize_reduction(*this, self, opt_dim, keepdim, out_dtype); } TORCH_META_FUNC2(prod, dim_int) @@ -197,7 +195,7 @@ TORCH_META_FUNC2(prod, dim_int) int64_t dim, bool keepdim, c10::optional dtype) { - auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, dtype, maybe_get_output()); + auto out_dtype = infer_dtype_from_optional(self, dtype, maybe_get_output()); resize_reduction(*this, self, dim, keepdim, out_dtype); } @@ -221,7 +219,7 @@ TORCH_META_FUNC2(mean, dim) "Got: ", dtype); } - auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output()); + auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output()); resize_reduction(*this, self, dim, keepdim, out_dtype); } @@ -1061,11 +1059,11 @@ inline ScalarType 
get_dtype_from_result(Tensor& result, optional dty TORCH_IMPL_FUNC(sum_out) (const Tensor& self, - IntArrayRef dim, + OptionalIntArrayRef opt_dim, bool keepdim, optional opt_dtype, const Tensor& result) { - auto iter = meta::make_reduction_from_out_ty(self, result, dim, keepdim, result.scalar_type()); + auto iter = meta::make_reduction_from_out_ty(self, result, opt_dim, keepdim, result.scalar_type()); if (iter.numel() == 0) { result.zero_(); } else { diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h index 8f63b9bd0b6..7c73c85d4c2 100644 --- a/aten/src/ATen/native/ReduceOpsUtils.h +++ b/aten/src/ATen/native/ReduceOpsUtils.h @@ -110,12 +110,27 @@ static inline Tensor integer_upcast(const Tensor& self, optional dty using DimMask = TensorIterator::DimMask; -static DimMask make_dim_mask(IntArrayRef dims, int64_t ndim) { - DimMask mask; - if (dims.empty()) { - mask = DimMask().flip(); +static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) { + if (opt_dims.has_value()) { + return DimVector(opt_dims.value()); } else { - mask = at::dim_list_to_bitset(dims, ndim); + std::vector all_dims(ndim); + std::iota(all_dims.begin(), all_dims.end(), 0); + return DimVector(all_dims); + } +} + +static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim) { + DimMask mask; + if (opt_dims.has_value()) { + auto dims = opt_dims.value(); + if (dims.empty()) { + mask = DimMask().flip(); + } else { + mask = at::dim_list_to_bitset(dims, ndim); + } + } else { + mask = DimMask().flip(); } return mask; } @@ -320,10 +335,10 @@ static C10_UNUSED DimVector get_reduction_shape( static void resize_reduction( impl::MetaBase& meta, const Tensor& self, - IntArrayRef dims, + OptionalIntArrayRef opt_dims, bool keepdim, ScalarType out_dtype) { - DimVector dims_(dims); + DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim()); maybe_wrap_dims(dims_, self.dim()); auto shape = get_reduction_shape(self, dims_, keepdim); 
meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype)); @@ -351,11 +366,11 @@ static void resize_reduction_with_indices( static TensorIterator make_reduction( const Tensor& self, const Tensor& result, - IntArrayRef dims, + OptionalIntArrayRef opt_dims, bool keepdim, ScalarType in_dtype) { int64_t ndim = self.dim(); - auto mask = at::native::make_dim_mask(dims, ndim); + auto mask = at::native::make_dim_mask(opt_dims, ndim); auto viewed_result = at::native::review_reduce_result(result, ndim, mask, keepdim); if (self.scalar_type() == in_dtype) { @@ -389,7 +404,7 @@ static TensorIterator make_reduction( static C10_UNUSED TensorIterator make_reduction_from_out_ty( const Tensor& self, const Tensor& result, - IntArrayRef dims, + OptionalIntArrayRef opt_dims, bool keepdim, ScalarType out_dtype) { // special case for type promotion in mixed precision, improves computational @@ -401,7 +416,7 @@ static C10_UNUSED TensorIterator make_reduction_from_out_ty( (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) && out_dtype == kFloat); auto in_dtype = gpu_lowp_to_f32 ? 
self.scalar_type() : out_dtype; - return make_reduction(self, result, dims, keepdim, in_dtype); + return make_reduction(self, result, opt_dims, keepdim, in_dtype); } } // namespace meta diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 51339167db7..35c618ed9a6 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -84,13 +84,15 @@ void set_apparent_shapes(NSMutableArray * &apparent_out_shape, // Helper function to set the axes of reduction void set_axes(NSMutableArray * &axes, int64_t num_reduce_dims, - IntArrayRef& dim, + OptionalIntArrayRef opt_dim, int64_t num_input_dims) { if(num_reduce_dims == 0) { axes = [NSMutableArray arrayWithCapacity:1]; axes[0] = @0; } else { + TORCH_INTERNAL_ASSERT(opt_dim.has_value()); + IntArrayRef dim = opt_dim.value(); axes = [NSMutableArray arrayWithCapacity:num_reduce_dims]; for(int i = 0; i < num_reduce_dims; i++) { axes[i] = [NSNumber numberWithInt:maybe_wrap_dim(dim[i], num_input_dims)]; @@ -100,7 +102,7 @@ void set_axes(NSMutableArray * &axes, // Helper function to prepare axes and tensor shapes void set_axes_and_shapes(const Tensor& input_t, - IntArrayRef dims, + OptionalIntArrayRef opt_dims, NSMutableArray * &axes, NSMutableArray * &apparent_input_shape, NSMutableArray * &apparent_output_shape, @@ -109,13 +111,13 @@ void set_axes_and_shapes(const Tensor& input_t, IntArrayRef input_shape = input_t.sizes(); int64_t num_input_dims = input_shape.size(); - int64_t num_reduce_dims = dims.size(); + int64_t num_reduce_dims = opt_dims.has_value() ? opt_dims.value().size() : 0; int64_t num_output_dims; num_output_dims = num_reduce_dims == 0 ? 
1 : num_input_dims; // Reduction axes - set_axes(axes, num_reduce_dims, dims, input_shape.size()); + set_axes(axes, num_reduce_dims, opt_dims, input_shape.size()); // Shapes set_apparent_shapes(apparent_output_shape, @@ -137,7 +139,7 @@ void set_axes_and_shapes(const Tensor& input_t, void reduction_out_mps (const Tensor& input_t, - IntArrayRef dim, + OptionalIntArrayRef opt_dim, bool keepdim, c10::optional dtype, const Tensor& output_t, @@ -146,10 +148,13 @@ void reduction_out_mps IntArrayRef input_shape = input_t.sizes(); - for(int i = 0; i < dim.size(); i++) { - auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size()); - TORCH_CHECK(wrap_dim < input_shape.size(), - func_name+": reduction dim must be in the range of input shape") + if (opt_dim.has_value()) { + IntArrayRef dim = opt_dim.value(); + for(int i = 0; i < dim.size(); i++) { + auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size()); + TORCH_CHECK(wrap_dim < input_shape.size(), + func_name+": reduction dim must be in the range of input shape") + } } namespace native_mps = at::native::mps; @@ -159,7 +164,7 @@ void reduction_out_mps NSMutableArray *apparent_output_shape = nil; NSMutableArray *output_shape = nil; - set_axes_and_shapes(input_t, dim, axes, apparent_input_shape, apparent_output_shape, output_shape); + set_axes_and_shapes(input_t, opt_dim, axes, apparent_input_shape, apparent_output_shape, output_shape); auto cache_ = native_mps::MPSGraphCache::getInstance(); @@ -271,12 +276,12 @@ void reduction_out_mps TORCH_IMPL_FUNC(sum_out_mps) (const Tensor& input_t, - IntArrayRef dim, + OptionalIntArrayRef opt_dim, bool keepdim, c10::optional dtype, const Tensor& output_t) { - reduction_out_mps(input_t, dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps"); + reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps"); } TORCH_IMPL_FUNC(prod_out_mps) diff --git a/aten/src/ATen/native/native_functions.yaml 
b/aten/src/ATen/native/native_functions.yaml index 67f149e1858..bbbf3d22f40 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4563,7 +4563,7 @@ CompositeExplicitAutograd: sum SparseCsrCPU, SparseCsrCUDA: sum_csr -- func: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor structured_delegate: sum.IntList_out device_check: NoCheck # TensorIterator variants: function, method @@ -4572,7 +4572,7 @@ device_check: NoCheck # TensorIterator variants: function, method -- func: sum.IntList_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) +- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) structured: True device_check: NoCheck # TensorIterator dispatch: diff --git a/c10/util/OptionalArrayRef.h b/c10/util/OptionalArrayRef.h index 7ca375d7cb7..ff51f549a56 100644 --- a/c10/util/OptionalArrayRef.h +++ b/c10/util/OptionalArrayRef.h @@ -74,6 +74,9 @@ class OptionalArrayRef final { Args&&... args) : wrapped_opt_array_ref(ip, il, args...) 
{} + constexpr OptionalArrayRef(const std::initializer_list<T>& Vec) + : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {} + // Destructor ~OptionalArrayRef() = default; diff --git a/test/test_jit.py b/test/test_jit.py index 222150e4944..b4f8355f68c 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -5368,7 +5368,7 @@ a") def func2(x): return x.sum(dim=4) - # test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument + # test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument self.run_pass('constant_propagation', func.graph) self.run_pass('constant_propagation', func2.graph) g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False) diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py index 2b9b00191cc..5f7fd37f906 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -1195,8 +1195,11 @@ class TestNamedTensor(TestCase): check_output(op(t, 1), ['N', 'L']) check_output(op(t, -1), ['N', 'C']) check_output(op(t, 'C'), ['N', 'L']) - with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'): - op(t, None) + if op.__name__ in ['sum']: + check_output(op(t, None), []) + else: + with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'): + op(t, None) with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'): op(t, 'H') diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index f7bc820b326..8df3fecaff1 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1522,7 +1522,7 @@ self: grad.expand(self.sizes()) result: auto_linear -- name: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor +- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor self: sum_backward(grad, self.sizes(), dim, keepdim) result: auto_linear diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 72a1649e7a2..f7a4094de80 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -50,6 +50,9 @@ If :attr:`keepdim` is ``True``, the output tensor is of the same size as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1. Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the output tensor having 1 (or ``len(dim)``) fewer dimension(s). +"""}, {'opt_dim': """ + dim (int or tuple of ints, optional): the dimension or dimensions to reduce. + If ``None``, all dimensions are reduced. """}) single_dim_common = merge_dicts(reduceops_common_args, parse_kwargs(""" @@ -9651,7 +9654,7 @@ reduce over all of them. Args: {input} - {dim} + {opt_dim} {keepdim} Keyword args: diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index eff069ce9d1..77edee8d028 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -38,6 +39,7 @@ namespace details { using at::areAnyTensorSubclassLike; using at::IntArrayRef; +using at::OptionalIntArrayRef; using at::Scalar; using at::Tensor; using at::TensorList; @@ -556,35 +558,41 @@ Tensor deg2rad_backward(const Tensor& grad) { return at::mul(grad, at::native::wrapped_scalar_tensor(Scalar(M_PI_180))); } -Tensor unsqueeze_multiple(const Tensor& t, IntArrayRef dim, size_t n_dims) { - auto dim_size = dim.size(); - // Optimisation for two common cases - if (dim_size == 0) { - return t; - } else if (dim_size == 0) { - return t.unsqueeze(dim[0]); - } else { - auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims); - Tensor res = t; - for (const auto i : c10::irange(n_dims)) { - if (dims_to_unsqueeze[i]) { - res = res.unsqueeze(i); - } +Tensor unsqueeze_multiple( + const 
Tensor& t, + OptionalIntArrayRef opt_dim, + size_t n_dims) { + if (opt_dim.has_value()) { + IntArrayRef dim = opt_dim.value(); + auto dim_size = dim.size(); + // Optimisation for two common cases + if (dim_size == 0) { + return t; + } else if (dim_size == 1) { + return t.unsqueeze(dim[0]); } - return res; } + auto dims_to_unsqueeze = at::dim_list_to_bitset(opt_dim, n_dims); + Tensor res = t; + for (const auto i : c10::irange(n_dims)) { + if (dims_to_unsqueeze[i]) { + res = res.unsqueeze(i); + } + } + return res; } Tensor sum_backward( const Tensor& grad, IntArrayRef sizes, - IntArrayRef dims, + OptionalIntArrayRef opt_dims, bool keepdim) { - if (!keepdim && sizes.size() > 0 && dims.size() > 0) { - return unsqueeze_multiple(grad, dims, sizes.size()).expand(sizes); - } else { - return grad.expand(sizes); + if (!keepdim && sizes.size() > 0) { + if (opt_dims.has_value() && opt_dims.value().size() > 0) { + return unsqueeze_multiple(grad, opt_dims, sizes.size()).expand(sizes); + } } + return grad.expand(sizes); } Tensor nansum_backward( diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 061ed7d29a4..c22557c2b10 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -150,12 +150,12 @@ at::Tensor rad2deg_backward(const at::Tensor& grad); at::Tensor deg2rad_backward(const at::Tensor& grad); at::Tensor unsqueeze_multiple( const at::Tensor& t, - at::IntArrayRef dim, + at::OptionalIntArrayRef opt_dim, size_t n_dims); at::Tensor sum_backward( const at::Tensor& grad, at::IntArrayRef sizes, - at::IntArrayRef dims, + at::OptionalIntArrayRef opt_dims, bool keepdim); at::Tensor nansum_backward( const at::Tensor& grad, diff --git a/torch/csrc/jit/codegen/cuda/parser.cpp b/torch/csrc/jit/codegen/cuda/parser.cpp index 15e6e5737c2..eb8918fea75 100644 --- a/torch/csrc/jit/codegen/cuda/parser.cpp +++ b/torch/csrc/jit/codegen/cuda/parser.cpp @@ -2480,7 +2480,7 @@ class IrParser { { auto ptr_op = 
getOperatorForLiteral( - "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"); + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"); REGISTER_PARSE_RULE( ptr_op, { @@ -3857,7 +3857,7 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) { static auto reduction_operator_schema = getOperatorForLiteral( - "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)") + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)") ->schema(); if (node->matches(reduction_operator_schema)) { switch (offset) { diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index a4a307b51b8..f1021a552b2 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -1980,7 +1980,7 @@ class ShapePropagator : public PropertyPropBase { return true; } else if ( node->matches( - "aten::sum(Tensor self, int[] dim, bool keepdim, *, int? dtype) -> Tensor", + "aten::sum(Tensor self, int[]? dim, bool keepdim, *, int? dtype) -> Tensor", /*const_inputs=*/{attr::dim, attr::keepdim})) { auto& tp = tensor_types.at(0); auto sizes = tp->sizes().concrete_sizes().value(); diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index 8f81c8189e6..fe63de1a04d 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -94,7 +94,7 @@ bool isSupported(Node* node) { static const OperatorSet supported_reduction_set{ "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", - "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? 
dtype=None) -> Tensor", "aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor", "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", }; diff --git a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp index 47e8b321512..2cc7fd0396b 100644 --- a/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp +++ b/torch/csrc/jit/runtime/serialized_shape_function_registry.cpp @@ -2158,6 +2158,54 @@ def transpose(self: List[int], _4 = torch.append(out, self[idx]) return out +)=====") ++ std::string(R"=====(def sum_dim(self: List[int], + opt_dims: Optional[List[int]], + keep_dim: bool, + dt: Any) -> List[int]: + out = annotate(List[int], []) + if opt_dims is None: + dims:List[int] = [] + else: + dims = opt_dims + for idx in range(torch.len(self)): + is_mean_dim = False + for _0 in range(torch.len(dims)): + reduce_dim = dims[_0] + _1 = torch.len(self) + if torch.le(_1, 0): + dim_post_expr = 1 + else: + dim_post_expr = _1 + min = torch.neg(dim_post_expr) + max = torch.sub(dim_post_expr, 1) + if torch.lt(reduce_dim, min): + _2 = True + else: + _2 = torch.gt(reduce_dim, max) + if torch.__not__(_2): + pass + else: + ops.prim.RaiseException("AssertionError: ") + if torch.lt(reduce_dim, 0): + dim0 = torch.add(reduce_dim, dim_post_expr) + dim = dim0 + else: + dim = reduce_dim + if torch.eq(idx, dim): + is_mean_dim0 = True + else: + is_mean_dim0 = is_mean_dim + is_mean_dim = is_mean_dim0 + if is_mean_dim: + if keep_dim: + _3 = torch.append(out, 1) + else: + pass + else: + _4 = torch.append(out, self[idx]) + return out + )=====") + std::string(R"=====(def max_dim(self: List[int], dim: int, @@ -2749,7 +2797,7 @@ const OperatorMap& GetShapeFunctionMappings() { {"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "expand"}, {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", "expand_one_unused"}, {"aten::mean.dim(Tensor 
self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"}, - {"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"}, + {"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "sum_dim"}, {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "max_dim"}, {"aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"}, {"aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"}, diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index f5d38fa6061..70b9323b7e6 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1691,10 +1691,10 @@ REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator { }; } if (n->matches(torch::schema( - "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) { + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) { return [](ProcessedNode* p_node) { const at::Tensor& self = p_node->Input(0).toTensor(); - auto dim = p_node->Input(1).toIntList().vec(); + auto dim = p_node->Input(1).toDimVector(); auto keepdim = p_node->Input(2).toBool(); auto dtype = p_node->Input(3).toOptional(); if (p_node->Output(0).isNone()) { diff --git a/torch/csrc/jit/tensorexpr/lowerings.cpp b/torch/csrc/jit/tensorexpr/lowerings.cpp index 6d1681c4227..e6f56a655eb 100644 --- a/torch/csrc/jit/tensorexpr/lowerings.cpp +++ b/torch/csrc/jit/tensorexpr/lowerings.cpp @@ -1767,7 +1767,7 @@ int nnc_lowerings_lazy_registration() { RegisterNNCLoweringsFunction aten_sum( {"aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)", - "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? 
dtype=None) -> (Tensor)"}, + "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"}, computeSum); RegisterNNCLoweringsFunction aten_softmax( diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index 89c1c40defc..c76c9e1823f 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -1004,7 +1004,7 @@ add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", add_shape_compute_mapping("aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand) add_shape_compute_mapping("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", expand_one_unused) add_shape_compute_mapping("aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim) -add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim) +add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim) add_shape_compute_mapping("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", max_dim) add_shape_compute_mapping("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor) add_shape_compute_mapping("aten::sum(Tensor self, *, ScalarType? 
dtype=None) -> Tensor", zero_dim_tensor) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a5acaff9480..e51671c82fe 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -18778,9 +18778,6 @@ op_db: List[OpInfo] = [ # FIXME: sum reduces all dimensions when dim=[] DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'), DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'), - # FIXME: sum does not support passing None to dim - DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'), - DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'), # FIXME: improve precision DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input', dtypes=[torch.float16]),