mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Reland: Enable dim=None for torch.sum (#79881)
Part of #29137 Reland of #75845 Pull Request resolved: https://github.com/pytorch/pytorch/pull/79881 Approved by: https://github.com/albanD, https://github.com/kulinseth
This commit is contained in:
parent
3f56a1b8c0
commit
23bdb570cf
25 changed files with 195 additions and 95 deletions
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
|
|
@ -1 +1 @@
|
|||
cc19c3abcbb3f702d5f468ee08549edd926ef549
|
||||
2bdd718b4b7309b5868825e261ae05bef6be548f
|
||||
|
|
|
|||
|
|
@ -56,18 +56,21 @@ static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
|
|||
return dim == 0 || dim == -1;
|
||||
}
|
||||
|
||||
Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, optional<ScalarType> dtype) {
|
||||
// PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
|
||||
// and instead returns a new scalar tensor (this also happens for dim=-1)
|
||||
// If the following happens:
|
||||
// >>> x = torch.randn(B0) # the per-examples are all scalars
|
||||
// >>> vmap(partial(torch.sum, dim=0), x)
|
||||
// then we replicate the behavior of sum(scalar_tensor, dim=0).
|
||||
if (/*logical*/self.dim() == 0 && (dims.size() == 0 || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
|
||||
return self.clone();
|
||||
Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, optional<ScalarType> dtype) {
|
||||
if (opt_dims.has_value()) {
|
||||
auto dims = opt_dims.value();
|
||||
// PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
|
||||
// and instead returns a new scalar tensor (this also happens for dim=-1)
|
||||
// If the following happens:
|
||||
// >>> x = torch.randn(B0) # the per-examples are all scalars
|
||||
// >>> vmap(partial(torch.sum, dim=0), x)
|
||||
// then we replicate the behavior of sum(scalar_tensor, dim=0).
|
||||
if (/*logical*/self.dim() == 0 && (dims.size() == 0 || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
|
||||
return self.clone();
|
||||
}
|
||||
}
|
||||
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
|
||||
auto dims_physical = self_physical.getPhysicalDims(dims);
|
||||
auto dims_physical = self_physical.getPhysicalDims(opt_dims);
|
||||
auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
|
||||
return self_physical.getPhysicalToLogicalMap().apply(result);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,13 +55,20 @@ int64_t VmapPhysicalView::numLogicalDims() const {
|
|||
return /*physical*/tensor_.dim() - numBatchDims();
|
||||
}
|
||||
|
||||
VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const {
|
||||
VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const {
|
||||
auto logical_ndim = numLogicalDims();
|
||||
// NB: fmap doesn't have a SmallVector variant, so we don't use it here.
|
||||
VmapDimVector result;
|
||||
result.reserve(logical_ndim);
|
||||
for (auto dim : logical_dims) {
|
||||
result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
|
||||
if (opt_logical_dims.has_value()) {
|
||||
auto logical_dims = opt_logical_dims.value();
|
||||
for (auto dim : logical_dims) {
|
||||
result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
|
||||
}
|
||||
} else {
|
||||
for (int64_t dim = 0; dim < logical_ndim; dim++) {
|
||||
result.push_back(dim + numBatchDims());
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ struct TORCH_API VmapPhysicalView {
|
|||
// This is because the size of levels tell us that the first two dimensions
|
||||
// of `tensor_` are batch dimensions, so a logical dim of `n` is actually
|
||||
// a physical dim of `n + 2`.
|
||||
VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
|
||||
VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
|
||||
int64_t getPhysicalDim(int64_t logical_dim) const;
|
||||
|
||||
// Returns a VmapPhysicalToLogicalMap object. This can be used for
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ namespace at {
|
|||
constexpr size_t dim_bitset_size = 64;
|
||||
|
||||
static inline std::bitset<dim_bitset_size> dim_list_to_bitset(
|
||||
IntArrayRef dims,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
int64_t ndims) {
|
||||
TORCH_CHECK(
|
||||
ndims <= (int64_t)dim_bitset_size,
|
||||
|
|
@ -22,11 +22,21 @@ static inline std::bitset<dim_bitset_size> dim_list_to_bitset(
|
|||
dim_bitset_size,
|
||||
" dims are supported");
|
||||
std::bitset<dim_bitset_size> seen;
|
||||
for (const auto i : c10::irange(dims.size())) {
|
||||
size_t dim = maybe_wrap_dim(dims[i], ndims);
|
||||
TORCH_CHECK(
|
||||
!seen[dim], "dim ", dim, " appears multiple times in the list of dims");
|
||||
seen[dim] = true;
|
||||
if (opt_dims.has_value()) {
|
||||
auto dims = opt_dims.value();
|
||||
for (const auto i : c10::irange(dims.size())) {
|
||||
size_t dim = maybe_wrap_dim(dims[i], ndims);
|
||||
TORCH_CHECK(
|
||||
!seen[dim],
|
||||
"dim ",
|
||||
dim,
|
||||
" appears multiple times in the list of dims");
|
||||
seen[dim] = true;
|
||||
}
|
||||
} else {
|
||||
for (int64_t dim = 0; dim < ndims; dim++) {
|
||||
seen[dim] = true;
|
||||
}
|
||||
}
|
||||
return seen;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -455,7 +455,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
|
|||
// KERNEL(ADD_NS(norm), "norm.ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, IntArrayRef, bool, ScalarType), fp32_set_opt_dtype)
|
||||
// KERNEL(ADD_NS(norm), "norm.names_ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, DimnameList, bool, ScalarType), fp32_set_opt_dtype)
|
||||
KERNEL(ADD_NS(sum), "sum", Tensor (const Tensor &, c10::optional<ScalarType>), fp32_set_opt_dtype)
|
||||
KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
|
||||
KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, OptionalIntArrayRef, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
|
||||
KERNEL(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
|
||||
// fp32_append_dtype
|
||||
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
|
||||
|
|
|
|||
|
|
@ -52,8 +52,6 @@ namespace meta {
|
|||
|
||||
static ScalarType infer_dtype_from_optional(
|
||||
const Tensor& self,
|
||||
IntArrayRef dim,
|
||||
bool keepdim,
|
||||
const optional<ScalarType>& opt_dtype,
|
||||
const Tensor& result) {
|
||||
// 'opt_dtype' has the priority for both cases.
|
||||
|
|
@ -187,9 +185,9 @@ TORCH_META_FUNC(cumprod)
|
|||
}
|
||||
|
||||
TORCH_META_FUNC2(sum, dim_IntList)
|
||||
(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
|
||||
auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output());
|
||||
resize_reduction(*this, self, dim, keepdim, out_dtype);
|
||||
(const Tensor& self, OptionalIntArrayRef opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
|
||||
auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
|
||||
resize_reduction(*this, self, opt_dim, keepdim, out_dtype);
|
||||
}
|
||||
|
||||
TORCH_META_FUNC2(prod, dim_int)
|
||||
|
|
@ -197,7 +195,7 @@ TORCH_META_FUNC2(prod, dim_int)
|
|||
int64_t dim,
|
||||
bool keepdim,
|
||||
c10::optional<ScalarType> dtype) {
|
||||
auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, dtype, maybe_get_output());
|
||||
auto out_dtype = infer_dtype_from_optional(self, dtype, maybe_get_output());
|
||||
resize_reduction(*this, self, dim, keepdim, out_dtype);
|
||||
}
|
||||
|
||||
|
|
@ -221,7 +219,7 @@ TORCH_META_FUNC2(mean, dim)
|
|||
"Got: ", dtype);
|
||||
}
|
||||
|
||||
auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output());
|
||||
auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
|
||||
resize_reduction(*this, self, dim, keepdim, out_dtype);
|
||||
}
|
||||
|
||||
|
|
@ -1061,11 +1059,11 @@ inline ScalarType get_dtype_from_result(Tensor& result, optional<ScalarType> dty
|
|||
|
||||
TORCH_IMPL_FUNC(sum_out)
|
||||
(const Tensor& self,
|
||||
IntArrayRef dim,
|
||||
OptionalIntArrayRef opt_dim,
|
||||
bool keepdim,
|
||||
optional<ScalarType> opt_dtype,
|
||||
const Tensor& result) {
|
||||
auto iter = meta::make_reduction_from_out_ty(self, result, dim, keepdim, result.scalar_type());
|
||||
auto iter = meta::make_reduction_from_out_ty(self, result, opt_dim, keepdim, result.scalar_type());
|
||||
if (iter.numel() == 0) {
|
||||
result.zero_();
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -110,12 +110,27 @@ static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dty
|
|||
|
||||
using DimMask = TensorIterator::DimMask;
|
||||
|
||||
static DimMask make_dim_mask(IntArrayRef dims, int64_t ndim) {
|
||||
DimMask mask;
|
||||
if (dims.empty()) {
|
||||
mask = DimMask().flip();
|
||||
static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
|
||||
if (opt_dims.has_value()) {
|
||||
return DimVector(opt_dims.value());
|
||||
} else {
|
||||
mask = at::dim_list_to_bitset(dims, ndim);
|
||||
std::vector<int64_t> all_dims(ndim);
|
||||
std::iota(all_dims.begin(), all_dims.end(), 0);
|
||||
return DimVector(all_dims);
|
||||
}
|
||||
}
|
||||
|
||||
static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim) {
|
||||
DimMask mask;
|
||||
if (opt_dims.has_value()) {
|
||||
auto dims = opt_dims.value();
|
||||
if (dims.empty()) {
|
||||
mask = DimMask().flip();
|
||||
} else {
|
||||
mask = at::dim_list_to_bitset(dims, ndim);
|
||||
}
|
||||
} else {
|
||||
mask = DimMask().flip();
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
|
@ -320,10 +335,10 @@ static C10_UNUSED DimVector get_reduction_shape(
|
|||
static void resize_reduction(
|
||||
impl::MetaBase& meta,
|
||||
const Tensor& self,
|
||||
IntArrayRef dims,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
bool keepdim,
|
||||
ScalarType out_dtype) {
|
||||
DimVector dims_(dims);
|
||||
DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
|
||||
maybe_wrap_dims(dims_, self.dim());
|
||||
auto shape = get_reduction_shape(self, dims_, keepdim);
|
||||
meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
|
||||
|
|
@ -351,11 +366,11 @@ static void resize_reduction_with_indices(
|
|||
static TensorIterator make_reduction(
|
||||
const Tensor& self,
|
||||
const Tensor& result,
|
||||
IntArrayRef dims,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
bool keepdim,
|
||||
ScalarType in_dtype) {
|
||||
int64_t ndim = self.dim();
|
||||
auto mask = at::native::make_dim_mask(dims, ndim);
|
||||
auto mask = at::native::make_dim_mask(opt_dims, ndim);
|
||||
auto viewed_result =
|
||||
at::native::review_reduce_result(result, ndim, mask, keepdim);
|
||||
if (self.scalar_type() == in_dtype) {
|
||||
|
|
@ -389,7 +404,7 @@ static TensorIterator make_reduction(
|
|||
static C10_UNUSED TensorIterator make_reduction_from_out_ty(
|
||||
const Tensor& self,
|
||||
const Tensor& result,
|
||||
IntArrayRef dims,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
bool keepdim,
|
||||
ScalarType out_dtype) {
|
||||
// special case for type promotion in mixed precision, improves computational
|
||||
|
|
@ -401,7 +416,7 @@ static C10_UNUSED TensorIterator make_reduction_from_out_ty(
|
|||
(self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) &&
|
||||
out_dtype == kFloat);
|
||||
auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
|
||||
return make_reduction(self, result, dims, keepdim, in_dtype);
|
||||
return make_reduction(self, result, opt_dims, keepdim, in_dtype);
|
||||
}
|
||||
|
||||
} // namespace meta
|
||||
|
|
|
|||
|
|
@ -84,13 +84,15 @@ void set_apparent_shapes(NSMutableArray<NSNumber*> * &apparent_out_shape,
|
|||
// Helper function to set the axes of reduction
|
||||
void set_axes(NSMutableArray<NSNumber *> * &axes,
|
||||
int64_t num_reduce_dims,
|
||||
IntArrayRef& dim,
|
||||
OptionalIntArrayRef opt_dim,
|
||||
int64_t num_input_dims) {
|
||||
if(num_reduce_dims == 0) {
|
||||
axes = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
|
||||
axes[0] = @0;
|
||||
}
|
||||
else {
|
||||
TORCH_INTERNAL_ASSERT(opt_dim.has_value());
|
||||
IntArrayRef dim = opt_dim.value();
|
||||
axes = [NSMutableArray<NSNumber*> arrayWithCapacity:num_reduce_dims];
|
||||
for(int i = 0; i < num_reduce_dims; i++) {
|
||||
axes[i] = [NSNumber numberWithInt:maybe_wrap_dim(dim[i], num_input_dims)];
|
||||
|
|
@ -100,7 +102,7 @@ void set_axes(NSMutableArray<NSNumber *> * &axes,
|
|||
|
||||
// Helper function to prepare axes and tensor shapes
|
||||
void set_axes_and_shapes(const Tensor& input_t,
|
||||
IntArrayRef dims,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
NSMutableArray<NSNumber*> * &axes,
|
||||
NSMutableArray<NSNumber*> * &apparent_input_shape,
|
||||
NSMutableArray<NSNumber*> * &apparent_output_shape,
|
||||
|
|
@ -109,13 +111,13 @@ void set_axes_and_shapes(const Tensor& input_t,
|
|||
IntArrayRef input_shape = input_t.sizes();
|
||||
|
||||
int64_t num_input_dims = input_shape.size();
|
||||
int64_t num_reduce_dims = dims.size();
|
||||
int64_t num_reduce_dims = opt_dims.has_value() ? opt_dims.value().size() : 0;
|
||||
int64_t num_output_dims;
|
||||
|
||||
num_output_dims = num_reduce_dims == 0 ? 1 : num_input_dims;
|
||||
|
||||
// Reduction axes
|
||||
set_axes(axes, num_reduce_dims, dims, input_shape.size());
|
||||
set_axes(axes, num_reduce_dims, opt_dims, input_shape.size());
|
||||
|
||||
// Shapes
|
||||
set_apparent_shapes(apparent_output_shape,
|
||||
|
|
@ -137,7 +139,7 @@ void set_axes_and_shapes(const Tensor& input_t,
|
|||
|
||||
void reduction_out_mps
|
||||
(const Tensor& input_t,
|
||||
IntArrayRef dim,
|
||||
OptionalIntArrayRef opt_dim,
|
||||
bool keepdim,
|
||||
c10::optional<ScalarType> dtype,
|
||||
const Tensor& output_t,
|
||||
|
|
@ -146,10 +148,13 @@ void reduction_out_mps
|
|||
|
||||
IntArrayRef input_shape = input_t.sizes();
|
||||
|
||||
for(int i = 0; i < dim.size(); i++) {
|
||||
auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
|
||||
TORCH_CHECK(wrap_dim < input_shape.size(),
|
||||
func_name+": reduction dim must be in the range of input shape")
|
||||
if (opt_dim.has_value()) {
|
||||
IntArrayRef dim = opt_dim.value();
|
||||
for(int i = 0; i < dim.size(); i++) {
|
||||
auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
|
||||
TORCH_CHECK(wrap_dim < input_shape.size(),
|
||||
func_name+": reduction dim must be in the range of input shape")
|
||||
}
|
||||
}
|
||||
|
||||
namespace native_mps = at::native::mps;
|
||||
|
|
@ -159,7 +164,7 @@ void reduction_out_mps
|
|||
NSMutableArray<NSNumber*> *apparent_output_shape = nil;
|
||||
NSMutableArray<NSNumber*> *output_shape = nil;
|
||||
|
||||
set_axes_and_shapes(input_t, dim, axes, apparent_input_shape, apparent_output_shape, output_shape);
|
||||
set_axes_and_shapes(input_t, opt_dim, axes, apparent_input_shape, apparent_output_shape, output_shape);
|
||||
|
||||
auto cache_ = native_mps::MPSGraphCache::getInstance();
|
||||
|
||||
|
|
@ -271,12 +276,12 @@ void reduction_out_mps
|
|||
|
||||
TORCH_IMPL_FUNC(sum_out_mps)
|
||||
(const Tensor& input_t,
|
||||
IntArrayRef dim,
|
||||
OptionalIntArrayRef opt_dim,
|
||||
bool keepdim,
|
||||
c10::optional<ScalarType> dtype,
|
||||
const Tensor& output_t) {
|
||||
|
||||
reduction_out_mps(input_t, dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps");
|
||||
reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(prod_out_mps)
|
||||
|
|
|
|||
|
|
@ -4563,7 +4563,7 @@
|
|||
CompositeExplicitAutograd: sum
|
||||
SparseCsrCPU, SparseCsrCUDA: sum_csr
|
||||
|
||||
- func: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
|
||||
- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
|
||||
structured_delegate: sum.IntList_out
|
||||
device_check: NoCheck # TensorIterator
|
||||
variants: function, method
|
||||
|
|
@ -4572,7 +4572,7 @@
|
|||
device_check: NoCheck # TensorIterator
|
||||
variants: function, method
|
||||
|
||||
- func: sum.IntList_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
|
||||
- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
|
||||
structured: True
|
||||
device_check: NoCheck # TensorIterator
|
||||
dispatch:
|
||||
|
|
|
|||
|
|
@ -74,6 +74,9 @@ class OptionalArrayRef final {
|
|||
Args&&... args)
|
||||
: wrapped_opt_array_ref(ip, il, args...) {}
|
||||
|
||||
constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
|
||||
: wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
|
||||
|
||||
// Destructor
|
||||
|
||||
~OptionalArrayRef() = default;
|
||||
|
|
|
|||
|
|
@ -5368,7 +5368,7 @@ a")
|
|||
def func2(x):
|
||||
return x.sum(dim=4)
|
||||
|
||||
# test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
|
||||
# test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument
|
||||
self.run_pass('constant_propagation', func.graph)
|
||||
self.run_pass('constant_propagation', func2.graph)
|
||||
g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
|
||||
|
|
|
|||
|
|
@ -1195,8 +1195,11 @@ class TestNamedTensor(TestCase):
|
|||
check_output(op(t, 1), ['N', 'L'])
|
||||
check_output(op(t, -1), ['N', 'C'])
|
||||
check_output(op(t, 'C'), ['N', 'L'])
|
||||
with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
|
||||
op(t, None)
|
||||
if op.__name__ in ['sum']:
|
||||
check_output(op(t, None), [])
|
||||
else:
|
||||
with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
|
||||
op(t, None)
|
||||
with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'):
|
||||
op(t, 'H')
|
||||
|
||||
|
|
|
|||
|
|
@ -1522,7 +1522,7 @@
|
|||
self: grad.expand(self.sizes())
|
||||
result: auto_linear
|
||||
|
||||
- name: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
|
||||
- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
|
||||
self: sum_backward(grad, self.sizes(), dim, keepdim)
|
||||
result: auto_linear
|
||||
|
||||
|
|
|
|||
|
|
@ -50,6 +50,9 @@ If :attr:`keepdim` is ``True``, the output tensor is of the same size
|
|||
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of size 1.
|
||||
Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting in the
|
||||
output tensor having 1 (or ``len(dim)``) fewer dimension(s).
|
||||
"""}, {'opt_dim': """
|
||||
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
|
||||
If ``None``, all dimensions are reduced.
|
||||
"""})
|
||||
|
||||
single_dim_common = merge_dicts(reduceops_common_args, parse_kwargs("""
|
||||
|
|
@ -9651,7 +9654,7 @@ reduce over all of them.
|
|||
|
||||
Args:
|
||||
{input}
|
||||
{dim}
|
||||
{opt_dim}
|
||||
{keepdim}
|
||||
|
||||
Keyword args:
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
#include <ATen/native/IndexingUtils.h>
|
||||
#include <ATen/native/LinearAlgebraUtils.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <c10/util/OptionalArrayRef.h>
|
||||
#include <c10/util/SmallBuffer.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
|
@ -38,6 +39,7 @@ namespace details {
|
|||
|
||||
using at::areAnyTensorSubclassLike;
|
||||
using at::IntArrayRef;
|
||||
using at::OptionalIntArrayRef;
|
||||
using at::Scalar;
|
||||
using at::Tensor;
|
||||
using at::TensorList;
|
||||
|
|
@ -556,35 +558,41 @@ Tensor deg2rad_backward(const Tensor& grad) {
|
|||
return at::mul(grad, at::native::wrapped_scalar_tensor(Scalar(M_PI_180)));
|
||||
}
|
||||
|
||||
Tensor unsqueeze_multiple(const Tensor& t, IntArrayRef dim, size_t n_dims) {
|
||||
auto dim_size = dim.size();
|
||||
// Optimisation for two common cases
|
||||
if (dim_size == 0) {
|
||||
return t;
|
||||
} else if (dim_size == 0) {
|
||||
return t.unsqueeze(dim[0]);
|
||||
} else {
|
||||
auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
|
||||
Tensor res = t;
|
||||
for (const auto i : c10::irange(n_dims)) {
|
||||
if (dims_to_unsqueeze[i]) {
|
||||
res = res.unsqueeze(i);
|
||||
}
|
||||
Tensor unsqueeze_multiple(
|
||||
const Tensor& t,
|
||||
OptionalIntArrayRef opt_dim,
|
||||
size_t n_dims) {
|
||||
if (opt_dim.has_value()) {
|
||||
IntArrayRef dim = opt_dim.value();
|
||||
auto dim_size = dim.size();
|
||||
// Optimisation for two common cases
|
||||
if (dim_size == 0) {
|
||||
return t;
|
||||
} else if (dim_size == 1) {
|
||||
return t.unsqueeze(dim[0]);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
auto dims_to_unsqueeze = at::dim_list_to_bitset(opt_dim, n_dims);
|
||||
Tensor res = t;
|
||||
for (const auto i : c10::irange(n_dims)) {
|
||||
if (dims_to_unsqueeze[i]) {
|
||||
res = res.unsqueeze(i);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
Tensor sum_backward(
|
||||
const Tensor& grad,
|
||||
IntArrayRef sizes,
|
||||
IntArrayRef dims,
|
||||
OptionalIntArrayRef opt_dims,
|
||||
bool keepdim) {
|
||||
if (!keepdim && sizes.size() > 0 && dims.size() > 0) {
|
||||
return unsqueeze_multiple(grad, dims, sizes.size()).expand(sizes);
|
||||
} else {
|
||||
return grad.expand(sizes);
|
||||
if (!keepdim && sizes.size() > 0) {
|
||||
if (opt_dims.has_value() && opt_dims.value().size() > 0) {
|
||||
return unsqueeze_multiple(grad, opt_dims, sizes.size()).expand(sizes);
|
||||
}
|
||||
}
|
||||
return grad.expand(sizes);
|
||||
}
|
||||
|
||||
Tensor nansum_backward(
|
||||
|
|
|
|||
|
|
@ -150,12 +150,12 @@ at::Tensor rad2deg_backward(const at::Tensor& grad);
|
|||
at::Tensor deg2rad_backward(const at::Tensor& grad);
|
||||
at::Tensor unsqueeze_multiple(
|
||||
const at::Tensor& t,
|
||||
at::IntArrayRef dim,
|
||||
at::OptionalIntArrayRef opt_dim,
|
||||
size_t n_dims);
|
||||
at::Tensor sum_backward(
|
||||
const at::Tensor& grad,
|
||||
at::IntArrayRef sizes,
|
||||
at::IntArrayRef dims,
|
||||
at::OptionalIntArrayRef opt_dims,
|
||||
bool keepdim);
|
||||
at::Tensor nansum_backward(
|
||||
const at::Tensor& grad,
|
||||
|
|
|
|||
|
|
@ -2480,7 +2480,7 @@ class IrParser {
|
|||
|
||||
{
|
||||
auto ptr_op = getOperatorForLiteral(
|
||||
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
|
||||
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
|
||||
REGISTER_PARSE_RULE(
|
||||
ptr_op,
|
||||
{
|
||||
|
|
@ -3857,7 +3857,7 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
|
|||
|
||||
static auto reduction_operator_schema =
|
||||
getOperatorForLiteral(
|
||||
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
|
||||
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
|
||||
->schema();
|
||||
if (node->matches(reduction_operator_schema)) {
|
||||
switch (offset) {
|
||||
|
|
|
|||
|
|
@ -1980,7 +1980,7 @@ class ShapePropagator : public PropertyPropBase {
|
|||
return true;
|
||||
} else if (
|
||||
node->matches(
|
||||
"aten::sum(Tensor self, int[] dim, bool keepdim, *, int? dtype) -> Tensor",
|
||||
"aten::sum(Tensor self, int[]? dim, bool keepdim, *, int? dtype) -> Tensor",
|
||||
/*const_inputs=*/{attr::dim, attr::keepdim})) {
|
||||
auto& tp = tensor_types.at(0);
|
||||
auto sizes = tp->sizes().concrete_sizes().value();
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ bool isSupported(Node* node) {
|
|||
|
||||
static const OperatorSet supported_reduction_set{
|
||||
"aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
|
||||
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
|
||||
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
|
||||
"aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor",
|
||||
"aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2158,6 +2158,54 @@ def transpose(self: List[int],
|
|||
_4 = torch.append(out, self[idx])
|
||||
return out
|
||||
|
||||
)=====")
|
||||
+ std::string(R"=====(def sum_dim(self: List[int],
|
||||
opt_dims: Optional[List[int]],
|
||||
keep_dim: bool,
|
||||
dt: Any) -> List[int]:
|
||||
out = annotate(List[int], [])
|
||||
if opt_dims is None:
|
||||
dims:List[int] = []
|
||||
else:
|
||||
dims = opt_dims
|
||||
for idx in range(torch.len(self)):
|
||||
is_mean_dim = False
|
||||
for _0 in range(torch.len(dims)):
|
||||
reduce_dim = dims[_0]
|
||||
_1 = torch.len(self)
|
||||
if torch.le(_1, 0):
|
||||
dim_post_expr = 1
|
||||
else:
|
||||
dim_post_expr = _1
|
||||
min = torch.neg(dim_post_expr)
|
||||
max = torch.sub(dim_post_expr, 1)
|
||||
if torch.lt(reduce_dim, min):
|
||||
_2 = True
|
||||
else:
|
||||
_2 = torch.gt(reduce_dim, max)
|
||||
if torch.__not__(_2):
|
||||
pass
|
||||
else:
|
||||
ops.prim.RaiseException("AssertionError: ")
|
||||
if torch.lt(reduce_dim, 0):
|
||||
dim0 = torch.add(reduce_dim, dim_post_expr)
|
||||
dim = dim0
|
||||
else:
|
||||
dim = reduce_dim
|
||||
if torch.eq(idx, dim):
|
||||
is_mean_dim0 = True
|
||||
else:
|
||||
is_mean_dim0 = is_mean_dim
|
||||
is_mean_dim = is_mean_dim0
|
||||
if is_mean_dim:
|
||||
if keep_dim:
|
||||
_3 = torch.append(out, 1)
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
_4 = torch.append(out, self[idx])
|
||||
return out
|
||||
|
||||
)=====")
|
||||
+ std::string(R"=====(def max_dim(self: List[int],
|
||||
dim: int,
|
||||
|
|
@ -2749,7 +2797,7 @@ const OperatorMap<std::string>& GetShapeFunctionMappings() {
|
|||
{"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "expand"},
|
||||
{"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", "expand_one_unused"},
|
||||
{"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"},
|
||||
{"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"},
|
||||
{"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "sum_dim"},
|
||||
{"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "max_dim"},
|
||||
{"aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},
|
||||
{"aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},
|
||||
|
|
|
|||
|
|
@ -1691,10 +1691,10 @@ REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
|
|||
};
|
||||
}
|
||||
if (n->matches(torch::schema(
|
||||
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
|
||||
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
|
||||
return [](ProcessedNode* p_node) {
|
||||
const at::Tensor& self = p_node->Input(0).toTensor();
|
||||
auto dim = p_node->Input(1).toIntList().vec();
|
||||
auto dim = p_node->Input(1).toDimVector();
|
||||
auto keepdim = p_node->Input(2).toBool();
|
||||
auto dtype = p_node->Input(3).toOptional<at::ScalarType>();
|
||||
if (p_node->Output(0).isNone()) {
|
||||
|
|
|
|||
|
|
@ -1767,7 +1767,7 @@ int nnc_lowerings_lazy_registration() {
|
|||
|
||||
RegisterNNCLoweringsFunction aten_sum(
|
||||
{"aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)",
|
||||
"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
|
||||
"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
|
||||
computeSum);
|
||||
|
||||
RegisterNNCLoweringsFunction aten_softmax(
|
||||
|
|
|
|||
|
|
@ -1004,7 +1004,7 @@ add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)",
|
|||
add_shape_compute_mapping("aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand)
|
||||
add_shape_compute_mapping("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", expand_one_unused)
|
||||
add_shape_compute_mapping("aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
|
||||
add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
|
||||
add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
|
||||
add_shape_compute_mapping("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", max_dim)
|
||||
add_shape_compute_mapping("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)
|
||||
add_shape_compute_mapping("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)
|
||||
|
|
|
|||
|
|
@ -18778,9 +18778,6 @@ op_db: List[OpInfo] = [
|
|||
# FIXME: sum reduces all dimensions when dim=[]
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
|
||||
# FIXME: sum does not support passing None to dim
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
|
||||
# FIXME: improve precision
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
|
||||
dtypes=[torch.float16]),
|
||||
|
|
|
|||
Loading…
Reference in a new issue