Reland: Enable dim=None for torch.sum (#79881)

Part of #29137

Reland of #75845
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79881
Approved by: https://github.com/albanD, https://github.com/kulinseth
This commit is contained in:
Kurt Mohler 2022-07-09 00:54:42 +00:00 committed by PyTorch MergeBot
parent 3f56a1b8c0
commit 23bdb570cf
25 changed files with 195 additions and 95 deletions

View file

@ -1 +1 @@
cc19c3abcbb3f702d5f468ee08549edd926ef549
2bdd718b4b7309b5868825e261ae05bef6be548f

View file

@ -56,18 +56,21 @@ static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
return dim == 0 || dim == -1;
}
Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, optional<ScalarType> dtype) {
// PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
// and instead returns a new scalar tensor (this also happens for dim=-1)
// If the following happens:
// >>> x = torch.randn(B0) # the per-examples are all scalars
// >>> vmap(partial(torch.sum, dim=0), x)
// then we replicate the behavior of sum(scalar_tensor, dim=0).
if (/*logical*/self.dim() == 0 && (dims.size() == 0 || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
return self.clone();
Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, optional<ScalarType> dtype) {
if (opt_dims.has_value()) {
auto dims = opt_dims.value();
// PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
// and instead returns a new scalar tensor (this also happens for dim=-1)
// If the following happens:
// >>> x = torch.randn(B0) # the per-examples are all scalars
// >>> vmap(partial(torch.sum, dim=0), x)
// then we replicate the behavior of sum(scalar_tensor, dim=0).
if (/*logical*/self.dim() == 0 && (dims.size() == 0 || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
return self.clone();
}
}
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dims_physical = self_physical.getPhysicalDims(dims);
auto dims_physical = self_physical.getPhysicalDims(opt_dims);
auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
return self_physical.getPhysicalToLogicalMap().apply(result);
}

View file

@ -55,13 +55,20 @@ int64_t VmapPhysicalView::numLogicalDims() const {
return /*physical*/tensor_.dim() - numBatchDims();
}
VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const {
VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const {
auto logical_ndim = numLogicalDims();
// NB: fmap doesn't have a SmallVector variant, so we don't use it here.
VmapDimVector result;
result.reserve(logical_ndim);
for (auto dim : logical_dims) {
result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
if (opt_logical_dims.has_value()) {
auto logical_dims = opt_logical_dims.value();
for (auto dim : logical_dims) {
result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
}
} else {
for (int64_t dim = 0; dim < logical_ndim; dim++) {
result.push_back(dim + numBatchDims());
}
}
return result;
}

View file

@ -131,7 +131,7 @@ struct TORCH_API VmapPhysicalView {
// This is because the size of levels tell us that the first two dimensions
// of `tensor_` are batch dimensions, so a logical dim of `n` is actually
// a physical dim of `n + 2`.
VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
int64_t getPhysicalDim(int64_t logical_dim) const;
// Returns a VmapPhysicalToLogicalMap object. This can be used for

View file

@ -14,7 +14,7 @@ namespace at {
constexpr size_t dim_bitset_size = 64;
static inline std::bitset<dim_bitset_size> dim_list_to_bitset(
IntArrayRef dims,
OptionalIntArrayRef opt_dims,
int64_t ndims) {
TORCH_CHECK(
ndims <= (int64_t)dim_bitset_size,
@ -22,11 +22,21 @@ static inline std::bitset<dim_bitset_size> dim_list_to_bitset(
dim_bitset_size,
" dims are supported");
std::bitset<dim_bitset_size> seen;
for (const auto i : c10::irange(dims.size())) {
size_t dim = maybe_wrap_dim(dims[i], ndims);
TORCH_CHECK(
!seen[dim], "dim ", dim, " appears multiple times in the list of dims");
seen[dim] = true;
if (opt_dims.has_value()) {
auto dims = opt_dims.value();
for (const auto i : c10::irange(dims.size())) {
size_t dim = maybe_wrap_dim(dims[i], ndims);
TORCH_CHECK(
!seen[dim],
"dim ",
dim,
" appears multiple times in the list of dims");
seen[dim] = true;
}
} else {
for (int64_t dim = 0; dim < ndims; dim++) {
seen[dim] = true;
}
}
return seen;
}

View file

@ -455,7 +455,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
// KERNEL(ADD_NS(norm), "norm.ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, IntArrayRef, bool, ScalarType), fp32_set_opt_dtype)
// KERNEL(ADD_NS(norm), "norm.names_ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, DimnameList, bool, ScalarType), fp32_set_opt_dtype)
KERNEL(ADD_NS(sum), "sum", Tensor (const Tensor &, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, OptionalIntArrayRef, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
// fp32_append_dtype
// The fp32_append_dtype wrapper overrides implicit promotion behavior.

View file

@ -52,8 +52,6 @@ namespace meta {
static ScalarType infer_dtype_from_optional(
const Tensor& self,
IntArrayRef dim,
bool keepdim,
const optional<ScalarType>& opt_dtype,
const Tensor& result) {
// 'opt_dtype' has the priority for both cases.
@ -187,9 +185,9 @@ TORCH_META_FUNC(cumprod)
}
TORCH_META_FUNC2(sum, dim_IntList)
(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
(const Tensor& self, OptionalIntArrayRef opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
resize_reduction(*this, self, opt_dim, keepdim, out_dtype);
}
TORCH_META_FUNC2(prod, dim_int)
@ -197,7 +195,7 @@ TORCH_META_FUNC2(prod, dim_int)
int64_t dim,
bool keepdim,
c10::optional<ScalarType> dtype) {
auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, dtype, maybe_get_output());
auto out_dtype = infer_dtype_from_optional(self, dtype, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
}
@ -221,7 +219,7 @@ TORCH_META_FUNC2(mean, dim)
"Got: ", dtype);
}
auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output());
auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
resize_reduction(*this, self, dim, keepdim, out_dtype);
}
@ -1061,11 +1059,11 @@ inline ScalarType get_dtype_from_result(Tensor& result, optional<ScalarType> dty
TORCH_IMPL_FUNC(sum_out)
(const Tensor& self,
IntArrayRef dim,
OptionalIntArrayRef opt_dim,
bool keepdim,
optional<ScalarType> opt_dtype,
const Tensor& result) {
auto iter = meta::make_reduction_from_out_ty(self, result, dim, keepdim, result.scalar_type());
auto iter = meta::make_reduction_from_out_ty(self, result, opt_dim, keepdim, result.scalar_type());
if (iter.numel() == 0) {
result.zero_();
} else {

View file

@ -110,12 +110,27 @@ static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dty
using DimMask = TensorIterator::DimMask;
static DimMask make_dim_mask(IntArrayRef dims, int64_t ndim) {
DimMask mask;
if (dims.empty()) {
mask = DimMask().flip();
static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
if (opt_dims.has_value()) {
return DimVector(opt_dims.value());
} else {
mask = at::dim_list_to_bitset(dims, ndim);
std::vector<int64_t> all_dims(ndim);
std::iota(all_dims.begin(), all_dims.end(), 0);
return DimVector(all_dims);
}
}
static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim) {
DimMask mask;
if (opt_dims.has_value()) {
auto dims = opt_dims.value();
if (dims.empty()) {
mask = DimMask().flip();
} else {
mask = at::dim_list_to_bitset(dims, ndim);
}
} else {
mask = DimMask().flip();
}
return mask;
}
@ -320,10 +335,10 @@ static C10_UNUSED DimVector get_reduction_shape(
static void resize_reduction(
impl::MetaBase& meta,
const Tensor& self,
IntArrayRef dims,
OptionalIntArrayRef opt_dims,
bool keepdim,
ScalarType out_dtype) {
DimVector dims_(dims);
DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
maybe_wrap_dims(dims_, self.dim());
auto shape = get_reduction_shape(self, dims_, keepdim);
meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
@ -351,11 +366,11 @@ static void resize_reduction_with_indices(
static TensorIterator make_reduction(
const Tensor& self,
const Tensor& result,
IntArrayRef dims,
OptionalIntArrayRef opt_dims,
bool keepdim,
ScalarType in_dtype) {
int64_t ndim = self.dim();
auto mask = at::native::make_dim_mask(dims, ndim);
auto mask = at::native::make_dim_mask(opt_dims, ndim);
auto viewed_result =
at::native::review_reduce_result(result, ndim, mask, keepdim);
if (self.scalar_type() == in_dtype) {
@ -389,7 +404,7 @@ static TensorIterator make_reduction(
static C10_UNUSED TensorIterator make_reduction_from_out_ty(
const Tensor& self,
const Tensor& result,
IntArrayRef dims,
OptionalIntArrayRef opt_dims,
bool keepdim,
ScalarType out_dtype) {
// special case for type promotion in mixed precision, improves computational
@ -401,7 +416,7 @@ static C10_UNUSED TensorIterator make_reduction_from_out_ty(
(self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) &&
out_dtype == kFloat);
auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
return make_reduction(self, result, dims, keepdim, in_dtype);
return make_reduction(self, result, opt_dims, keepdim, in_dtype);
}
} // namespace meta

View file

@ -84,13 +84,15 @@ void set_apparent_shapes(NSMutableArray<NSNumber*> * &apparent_out_shape,
// Helper function to set the axes of reduction
void set_axes(NSMutableArray<NSNumber *> * &axes,
int64_t num_reduce_dims,
IntArrayRef& dim,
OptionalIntArrayRef opt_dim,
int64_t num_input_dims) {
if(num_reduce_dims == 0) {
axes = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
axes[0] = @0;
}
else {
TORCH_INTERNAL_ASSERT(opt_dim.has_value());
IntArrayRef dim = opt_dim.value();
axes = [NSMutableArray<NSNumber*> arrayWithCapacity:num_reduce_dims];
for(int i = 0; i < num_reduce_dims; i++) {
axes[i] = [NSNumber numberWithInt:maybe_wrap_dim(dim[i], num_input_dims)];
@ -100,7 +102,7 @@ void set_axes(NSMutableArray<NSNumber *> * &axes,
// Helper function to prepare axes and tensor shapes
void set_axes_and_shapes(const Tensor& input_t,
IntArrayRef dims,
OptionalIntArrayRef opt_dims,
NSMutableArray<NSNumber*> * &axes,
NSMutableArray<NSNumber*> * &apparent_input_shape,
NSMutableArray<NSNumber*> * &apparent_output_shape,
@ -109,13 +111,13 @@ void set_axes_and_shapes(const Tensor& input_t,
IntArrayRef input_shape = input_t.sizes();
int64_t num_input_dims = input_shape.size();
int64_t num_reduce_dims = dims.size();
int64_t num_reduce_dims = opt_dims.has_value() ? opt_dims.value().size() : 0;
int64_t num_output_dims;
num_output_dims = num_reduce_dims == 0 ? 1 : num_input_dims;
// Reduction axes
set_axes(axes, num_reduce_dims, dims, input_shape.size());
set_axes(axes, num_reduce_dims, opt_dims, input_shape.size());
// Shapes
set_apparent_shapes(apparent_output_shape,
@ -137,7 +139,7 @@ void set_axes_and_shapes(const Tensor& input_t,
void reduction_out_mps
(const Tensor& input_t,
IntArrayRef dim,
OptionalIntArrayRef opt_dim,
bool keepdim,
c10::optional<ScalarType> dtype,
const Tensor& output_t,
@ -146,10 +148,13 @@ void reduction_out_mps
IntArrayRef input_shape = input_t.sizes();
for(int i = 0; i < dim.size(); i++) {
auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
TORCH_CHECK(wrap_dim < input_shape.size(),
func_name+": reduction dim must be in the range of input shape")
if (opt_dim.has_value()) {
IntArrayRef dim = opt_dim.value();
for(int i = 0; i < dim.size(); i++) {
auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
TORCH_CHECK(wrap_dim < input_shape.size(),
func_name+": reduction dim must be in the range of input shape")
}
}
namespace native_mps = at::native::mps;
@ -159,7 +164,7 @@ void reduction_out_mps
NSMutableArray<NSNumber*> *apparent_output_shape = nil;
NSMutableArray<NSNumber*> *output_shape = nil;
set_axes_and_shapes(input_t, dim, axes, apparent_input_shape, apparent_output_shape, output_shape);
set_axes_and_shapes(input_t, opt_dim, axes, apparent_input_shape, apparent_output_shape, output_shape);
auto cache_ = native_mps::MPSGraphCache::getInstance();
@ -271,12 +276,12 @@ void reduction_out_mps
TORCH_IMPL_FUNC(sum_out_mps)
(const Tensor& input_t,
IntArrayRef dim,
OptionalIntArrayRef opt_dim,
bool keepdim,
c10::optional<ScalarType> dtype,
const Tensor& output_t) {
reduction_out_mps(input_t, dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps");
reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps");
}
TORCH_IMPL_FUNC(prod_out_mps)

View file

@ -4563,7 +4563,7 @@
CompositeExplicitAutograd: sum
SparseCsrCPU, SparseCsrCUDA: sum_csr
- func: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
structured_delegate: sum.IntList_out
device_check: NoCheck # TensorIterator
variants: function, method
@ -4572,7 +4572,7 @@
device_check: NoCheck # TensorIterator
variants: function, method
- func: sum.IntList_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
structured: True
device_check: NoCheck # TensorIterator
dispatch:

View file

@ -74,6 +74,9 @@ class OptionalArrayRef final {
Args&&... args)
: wrapped_opt_array_ref(ip, il, args...) {}
constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
: wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
// Destructor
~OptionalArrayRef() = default;

View file

@ -5368,7 +5368,7 @@ a")
def func2(x):
return x.sum(dim=4)
# test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
# test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument
self.run_pass('constant_propagation', func.graph)
self.run_pass('constant_propagation', func2.graph)
g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)

View file

@ -1195,8 +1195,11 @@ class TestNamedTensor(TestCase):
check_output(op(t, 1), ['N', 'L'])
check_output(op(t, -1), ['N', 'C'])
check_output(op(t, 'C'), ['N', 'L'])
with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
op(t, None)
if op.__name__ in ['sum']:
check_output(op(t, None), [])
else:
with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
op(t, None)
with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'):
op(t, 'H')

View file

@ -1522,7 +1522,7 @@
self: grad.expand(self.sizes())
result: auto_linear
- name: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
self: sum_backward(grad, self.sizes(), dim, keepdim)
result: auto_linear

View file

@ -50,6 +50,9 @@ If :attr:`keepdim` is ``True``, the output tensor is of the same size
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
output tensor having 1 (or ``len(dim)``) fewer dimension(s).
"""}, {'opt_dim': """
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
If ``None``, all dimensions are reduced.
"""})
single_dim_common = merge_dicts(reduceops_common_args, parse_kwargs("""
@ -9651,7 +9654,7 @@ reduce over all of them.
Args:
{input}
{dim}
{opt_dim}
{keepdim}
Keyword args:

View file

@ -20,6 +20,7 @@
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/OptionalArrayRef.h>
#include <c10/util/SmallBuffer.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
@ -38,6 +39,7 @@ namespace details {
using at::areAnyTensorSubclassLike;
using at::IntArrayRef;
using at::OptionalIntArrayRef;
using at::Scalar;
using at::Tensor;
using at::TensorList;
@ -556,35 +558,41 @@ Tensor deg2rad_backward(const Tensor& grad) {
return at::mul(grad, at::native::wrapped_scalar_tensor(Scalar(M_PI_180)));
}
Tensor unsqueeze_multiple(const Tensor& t, IntArrayRef dim, size_t n_dims) {
auto dim_size = dim.size();
// Optimisation for two common cases
if (dim_size == 0) {
return t;
} else if (dim_size == 1) {
return t.unsqueeze(dim[0]);
} else {
auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
Tensor res = t;
for (const auto i : c10::irange(n_dims)) {
if (dims_to_unsqueeze[i]) {
res = res.unsqueeze(i);
}
Tensor unsqueeze_multiple(
const Tensor& t,
OptionalIntArrayRef opt_dim,
size_t n_dims) {
if (opt_dim.has_value()) {
IntArrayRef dim = opt_dim.value();
auto dim_size = dim.size();
// Optimisation for two common cases
if (dim_size == 0) {
return t;
} else if (dim_size == 1) {
return t.unsqueeze(dim[0]);
}
return res;
}
auto dims_to_unsqueeze = at::dim_list_to_bitset(opt_dim, n_dims);
Tensor res = t;
for (const auto i : c10::irange(n_dims)) {
if (dims_to_unsqueeze[i]) {
res = res.unsqueeze(i);
}
}
return res;
}
Tensor sum_backward(
const Tensor& grad,
IntArrayRef sizes,
IntArrayRef dims,
OptionalIntArrayRef opt_dims,
bool keepdim) {
if (!keepdim && sizes.size() > 0 && dims.size() > 0) {
return unsqueeze_multiple(grad, dims, sizes.size()).expand(sizes);
} else {
return grad.expand(sizes);
if (!keepdim && sizes.size() > 0) {
if (opt_dims.has_value() && opt_dims.value().size() > 0) {
return unsqueeze_multiple(grad, opt_dims, sizes.size()).expand(sizes);
}
}
return grad.expand(sizes);
}
Tensor nansum_backward(

View file

@ -150,12 +150,12 @@ at::Tensor rad2deg_backward(const at::Tensor& grad);
at::Tensor deg2rad_backward(const at::Tensor& grad);
at::Tensor unsqueeze_multiple(
const at::Tensor& t,
at::IntArrayRef dim,
at::OptionalIntArrayRef opt_dim,
size_t n_dims);
at::Tensor sum_backward(
const at::Tensor& grad,
at::IntArrayRef sizes,
at::IntArrayRef dims,
at::OptionalIntArrayRef opt_dims,
bool keepdim);
at::Tensor nansum_backward(
const at::Tensor& grad,

View file

@ -2480,7 +2480,7 @@ class IrParser {
{
auto ptr_op = getOperatorForLiteral(
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
REGISTER_PARSE_RULE(
ptr_op,
{
@ -3857,7 +3857,7 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
static auto reduction_operator_schema =
getOperatorForLiteral(
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
->schema();
if (node->matches(reduction_operator_schema)) {
switch (offset) {

View file

@ -1980,7 +1980,7 @@ class ShapePropagator : public PropertyPropBase {
return true;
} else if (
node->matches(
"aten::sum(Tensor self, int[] dim, bool keepdim, *, int? dtype) -> Tensor",
"aten::sum(Tensor self, int[]? dim, bool keepdim, *, int? dtype) -> Tensor",
/*const_inputs=*/{attr::dim, attr::keepdim})) {
auto& tp = tensor_types.at(0);
auto sizes = tp->sizes().concrete_sizes().value();

View file

@ -94,7 +94,7 @@ bool isSupported(Node* node) {
static const OperatorSet supported_reduction_set{
"aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
"aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor",
"aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
};

View file

@ -2158,6 +2158,54 @@ def transpose(self: List[int],
_4 = torch.append(out, self[idx])
return out
)=====")
+ std::string(R"=====(def sum_dim(self: List[int],
opt_dims: Optional[List[int]],
keep_dim: bool,
dt: Any) -> List[int]:
out = annotate(List[int], [])
if opt_dims is None:
dims:List[int] = []
else:
dims = opt_dims
for idx in range(torch.len(self)):
is_mean_dim = False
for _0 in range(torch.len(dims)):
reduce_dim = dims[_0]
_1 = torch.len(self)
if torch.le(_1, 0):
dim_post_expr = 1
else:
dim_post_expr = _1
min = torch.neg(dim_post_expr)
max = torch.sub(dim_post_expr, 1)
if torch.lt(reduce_dim, min):
_2 = True
else:
_2 = torch.gt(reduce_dim, max)
if torch.__not__(_2):
pass
else:
ops.prim.RaiseException("AssertionError: ")
if torch.lt(reduce_dim, 0):
dim0 = torch.add(reduce_dim, dim_post_expr)
dim = dim0
else:
dim = reduce_dim
if torch.eq(idx, dim):
is_mean_dim0 = True
else:
is_mean_dim0 = is_mean_dim
is_mean_dim = is_mean_dim0
if is_mean_dim:
if keep_dim:
_3 = torch.append(out, 1)
else:
pass
else:
_4 = torch.append(out, self[idx])
return out
)=====")
+ std::string(R"=====(def max_dim(self: List[int],
dim: int,
@ -2749,7 +2797,7 @@ const OperatorMap<std::string>& GetShapeFunctionMappings() {
{"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "expand"},
{"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", "expand_one_unused"},
{"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"},
{"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"},
{"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "sum_dim"},
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "max_dim"},
{"aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},
{"aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},

View file

@ -1691,10 +1691,10 @@ REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
};
}
if (n->matches(torch::schema(
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
return [](ProcessedNode* p_node) {
const at::Tensor& self = p_node->Input(0).toTensor();
auto dim = p_node->Input(1).toIntList().vec();
auto dim = p_node->Input(1).toDimVector();
auto keepdim = p_node->Input(2).toBool();
auto dtype = p_node->Input(3).toOptional<at::ScalarType>();
if (p_node->Output(0).isNone()) {

View file

@ -1767,7 +1767,7 @@ int nnc_lowerings_lazy_registration() {
RegisterNNCLoweringsFunction aten_sum(
{"aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)",
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
computeSum);
RegisterNNCLoweringsFunction aten_softmax(

View file

@ -1004,7 +1004,7 @@ add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)",
add_shape_compute_mapping("aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand)
add_shape_compute_mapping("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", expand_one_unused)
add_shape_compute_mapping("aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
add_shape_compute_mapping("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", max_dim)
add_shape_compute_mapping("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)
add_shape_compute_mapping("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)

View file

@ -18778,9 +18778,6 @@ op_db: List[OpInfo] = [
# FIXME: sum reduces all dimensions when dim=[]
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
# FIXME: sum does not support passing None to dim
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
# FIXME: improve precision
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
dtypes=[torch.float16]),