From e33f1eeeb73a8d680b8aae7944011389f76faaff Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 10 Dec 2022 20:23:17 -0800 Subject: [PATCH] SymIntify resize_ and deduplicate memory format logic (#90442) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/90442 Approved by: https://github.com/bdhirsh --- aten/src/ATen/EmptyTensor.cpp | 83 ++++------- aten/src/ATen/EmptyTensor.h | 4 + aten/src/ATen/native/Resize.cpp | 151 +++++++++++++++++++-- aten/src/ATen/native/Resize.h | 29 +--- aten/src/ATen/native/native_functions.yaml | 3 +- c10/core/MemoryFormat.h | 18 ++- c10/core/TensorImpl.cpp | 79 +++++++++++ c10/core/TensorImpl.h | 38 +++++- test/test_proxy_tensor.py | 12 ++ torch/csrc/autograd/VariableTypeManual.cpp | 4 +- 10 files changed, 314 insertions(+), 107 deletions(-) diff --git a/aten/src/ATen/EmptyTensor.cpp b/aten/src/ATen/EmptyTensor.cpp index daf0b684236..b59589d2bfb 100644 --- a/aten/src/ATen/EmptyTensor.cpp +++ b/aten/src/ATen/EmptyTensor.cpp @@ -106,6 +106,15 @@ size_t computeStorageNbytes( #endif } +SymInt computeStorageNbytesContiguous( + SymIntArrayRef sizes, + SymInt itemsize_bytes, + SymInt storage_offset + ) { + const auto numel = c10::multiply_integers(sizes); + return itemsize_bytes * (storage_offset + numel); +} + // not including mobile-only macros in this function, // since mobile shouldn't be using symints. SymInt computeStorageNbytes( @@ -135,8 +144,9 @@ SymInt computeStorageNbytes( return itemsize_bytes * (storage_offset + size); } -TensorBase empty_generic( - IntArrayRef size, +template +TensorBase _empty_generic( + ArrayRef size, c10::Allocator* allocator, c10::DispatchKeySet ks, ScalarType scalar_type, @@ -144,11 +154,10 @@ TensorBase empty_generic( at::detail::check_size_nonnegative(size); at::detail::raise_warning_for_complex_half(scalar_type); caffe2::TypeMeta dtype = scalarTypeToTypeMeta(scalar_type); - size_t size_bytes = computeStorageNbytesContiguous(size, dtype.itemsize()); + auto size_bytes = computeStorageNbytesContiguous(size, dtype.itemsize()); auto storage_impl = c10::make_intrusive( c10::StorageImpl::use_byte_size_t(), size_bytes, - allocator->allocate(size_bytes), allocator, /*resizeable=*/true); @@ -156,7 +165,7 @@ TensorBase empty_generic( std::move(storage_impl), ks, dtype); // Default TensorImpl has size [0] if (size.size() != 1 || size[0] != 0) { - tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); + tensor.unsafeGetTensorImpl()->generic_set_sizes_contiguous(size); } if (memory_format_opt.has_value()) { @@ -169,6 +178,15 @@ TensorBase empty_generic( return tensor; } +TensorBase empty_generic( + IntArrayRef size, + c10::Allocator* allocator, + c10::DispatchKeySet ks, + ScalarType scalar_type, + c10::optional memory_format_opt) { + return _empty_generic(size, allocator, ks, scalar_type, memory_format_opt); +} + template TensorBase _empty_strided_generic( T size, @@ -338,59 +356,10 @@ TensorBase empty_symint_meta( c10::optional pin_memory_opt, c10::optional memory_format_opt ) { - auto device = device_or_default(device_opt); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::Meta); - // NB: because there is no SparseMeta (yet), non-strided layout is - // exerciseable - TORCH_CHECK_NOT_IMPLEMENTED( - layout_or_default(layout_opt) == Layout::Strided, - "non-strided meta tensors not supported yet" - ); - - auto scalar_type = dtype_or_default(dtype_opt); auto *allocator = GetAllocator(kMeta); - constexpr c10::DispatchKeySet meta_dks(c10::DispatchKey::Meta); - 
at::detail::check_size_nonnegative(size); - at::detail::raise_warning_for_complex_half(scalar_type); - caffe2::TypeMeta dtype = scalarTypeToTypeMeta(scalar_type); - SymInt size_bytes = dtype.itemsize(); - for (auto s : size) { - size_bytes = size_bytes * s; - } - auto storage_impl = c10::make_intrusive( - c10::StorageImpl::use_byte_size_t(), - size_bytes, - allocator, - /*resizeable=*/true); - - auto tensor = detail::make_tensor_base( - std::move(storage_impl), meta_dks, dtype); - - int64_t dim = size.size(); - std::vector strides; - strides.resize(dim); - - // TODO: Move this into TensorImpl - auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous); - switch (memory_format) { - case MemoryFormat::Contiguous: { - if (dim > 0) { - const auto last_idx = dim - 1; - strides.at(last_idx) = 1; - for (auto i = last_idx - 1; i >= 0; --i) { - // TODO: max with 1 - strides.at(i) = strides.at(i+1) * size.at(i+1); - } - } - break; - } - default: - TORCH_CHECK(0, "other memory format not implemented yet"); - } - - tensor.unsafeGetTensorImpl()->set_sizes_and_strides(size, strides); - - return tensor; + constexpr c10::DispatchKeySet ks(c10::DispatchKey::Meta); + auto scalar_type = dtype_or_default(dtype_opt); + return _empty_generic(size, allocator, ks, scalar_type, memory_format_opt); } TensorBase empty_meta( diff --git a/aten/src/ATen/EmptyTensor.h b/aten/src/ATen/EmptyTensor.h index 969eeb6dc5e..b09c32e59e7 100644 --- a/aten/src/ATen/EmptyTensor.h +++ b/aten/src/ATen/EmptyTensor.h @@ -20,6 +20,10 @@ TORCH_API size_t computeStorageNbytesContiguous( IntArrayRef sizes, size_t itemsize, size_t storage_offset = 0); +TORCH_API SymInt computeStorageNbytesContiguous( + SymIntArrayRef sizes, + SymInt itemsize, + SymInt storage_offset = 0); TORCH_API size_t computeStorageNbytes( IntArrayRef sizes, IntArrayRef strides, diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp index bd47a25e696..abb58337d3a 100644 --- a/aten/src/ATen/native/Resize.cpp +++ b/aten/src/ATen/native/Resize.cpp @@ -10,20 +10,22 @@ #else #include #include +#include #endif namespace at { namespace native { // Returns true if resize is necessary -bool resize_output_check(const Tensor& output, IntArrayRef shape) { +template +bool _resize_output_check(const Tensor& output, ArrayRef shape) { // Tests for resizing of tensors with one or more elements - if (output.sizes().equals(shape)) { + if (at::symint::sizes(output).equals(shape)) { return false; } - if (output.numel() != 0) { + if (at::symint::numel(output) != 0) { TORCH_WARN( "An output with one or more elements was resized since it had ", - "shape ", output.sizes(), ", which does not match the required ", + "shape ", at::symint::sizes(output), ", which does not match the required ", "output shape ", shape, ". ", "This behavior is deprecated, and in a future PyTorch release outputs ", "will not be resized unless they have zero elements. 
You can explicitly ", @@ -33,8 +35,25 @@ bool resize_output_check(const Tensor& output, IntArrayRef shape) { return true; } -bool resize_output(const Tensor& output, IntArrayRef shape) { - if (resize_output_check(output, shape)) { +bool resize_output_check(const Tensor& output, IntArrayRef shape) { + return _resize_output_check(output, shape); +} + +bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape) { + return _resize_output_check(output, shape); +} + +void native_resize_(const Tensor& output, IntArrayRef shape) { + native::resize_(output, shape); +} + +void native_resize_(const Tensor& output, SymIntArrayRef shape) { + native::resize__symint(output, shape); +} + +template +bool _resize_output(const Tensor& output, ArrayRef shape) { + if (_resize_output_check(output, shape)) { // avoid a redispatch for cpu and cuda. // TODO: when resize_cuda_ is re-written to be unified with resize_, // we can provide the same benefit for cuda. @@ -42,9 +61,9 @@ bool resize_output(const Tensor& output, IntArrayRef shape) { // TODO(#61485): functorch wrapped tensors should not go through the // fast path. This is a hack, longer term solutions are in the issue if (output.is_cpu() && !isTensorSubclassLike(output)) { - at::native::resize_(output, shape); + native_resize_(output, shape); } else { - output.resize_(shape); + at::symint::resize_(output, shape); } return true; } else { @@ -52,6 +71,14 @@ bool resize_output(const Tensor& output, IntArrayRef shape) { } } +bool resize_output(const Tensor& output, IntArrayRef shape) { + return _resize_output(output, shape); +} + +bool resize_output_symint(const Tensor& output, SymIntArrayRef shape) { + return _resize_output(output, shape); +} + const Tensor& _resize_output_(const Tensor& self, IntArrayRef shape, c10::Device device) { TORCH_CHECK(self.device() == device, "out Tensor doesn't have the correct device set"); at::native::resize_output(self, shape); @@ -126,16 +153,92 @@ const Tensor& resize_as_( return result; } -const Tensor& resize_( - const Tensor& self, - IntArrayRef size, - c10::optional optional_memory_format) { - if (self.has_names()) { - return resize_named_tensor_(self, size, optional_memory_format); + +void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes) { + TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable"); + storage->set_nbytes(size_bytes); +} + +static void maybe_resize_storage_meta(TensorImpl* self, c10::SymInt new_size_bytes) { + // It does not make sense to try to resize a storage + // to hold 0 elements, and this can break + // if storage_offset is positive but + // new_size is 0, so just bail in that case + // (same comment is in Resize.h) + if (self->sym_numel() == 0) { + return; } + + const Storage& storage = self->unsafe_storage(); + if (!storage) { + TORCH_INTERNAL_ASSERT(0, "NYI, this should only be Caffe2"); + } else if (new_size_bytes > storage.nbytes()) { + resize_bytes_meta(storage.unsafeGetStorageImpl(), new_size_bytes); + } +} + +static void _maybe_resize_storage(TensorImpl* self, int64_t new_size_bytes) { + maybe_resize_storage_cpu(self, new_size_bytes); +} + +static void _maybe_resize_storage(TensorImpl* self, c10::SymInt new_size_bytes) { + maybe_resize_storage_meta(self, new_size_bytes); +} + +template +TensorImpl* _resize_impl_( + TensorImpl* self, + ArrayRef size, + at::OptionalArrayRef stride, + bool resize_storage) { + if (self->generic_sizes() == size && (!stride || self->generic_strides() == stride.value())) { + return self; + } + + const auto 
itemsize = self->dtype().itemsize(); + const auto storage_offset = self->generic_storage_offset(); + T storage_size = T(1); + if (stride) { + self->set_sizes_and_strides(size, *stride); + storage_size = at::detail::computeStorageNbytes( + size, *stride, itemsize, storage_offset); + } else { + self->generic_set_sizes_contiguous(size); + storage_size = at::detail::computeStorageNbytesContiguous( + size, itemsize, storage_offset); + } + + if (resize_storage) { + _maybe_resize_storage(self, storage_size); + } + + return self; +} + +TensorImpl* resize_impl_cpu_( + TensorImpl* self, + IntArrayRef size, + at::OptionalIntArrayRef stride, + bool resize_storage) { + return _resize_impl_(self, size, stride, resize_storage); +} + +TensorImpl* resize_impl_meta_( + TensorImpl* self, + c10::SymIntArrayRef size, + at::OptionalSymIntArrayRef stride, + bool resize_storage = true) { + return _resize_impl_(self, size, stride, resize_storage); +} + +template +const Tensor& _resize_( + const Tensor& self, + ArrayRef size, + c10::optional optional_memory_format) { auto* self_ = self.unsafeGetTensorImpl(); // NOLINTNEXTLINE(bugprone-argument-comment) - resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt); + _resize_impl_(self_, size, /*strides=*/c10::nullopt, true); if (optional_memory_format.has_value()) { auto memory_format = optional_memory_format.value(); @@ -148,5 +251,23 @@ const Tensor& resize_( return self; } +const Tensor& resize_( + const Tensor& self, + IntArrayRef size, + c10::optional optional_memory_format) { + if (self.has_names()) { + return resize_named_tensor_(self, size, optional_memory_format); + } + return _resize_(self, size, optional_memory_format); +} + +const Tensor& resize__symint( + const Tensor& self, + c10::SymIntArrayRef size, + c10::optional optional_memory_format) { + TORCH_INTERNAL_ASSERT(!self.has_names()) + return _resize_(self, size, optional_memory_format); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/Resize.h b/aten/src/ATen/native/Resize.h index 0bed4232695..d7408ba2229 100644 --- a/aten/src/ATen/native/Resize.h +++ b/aten/src/ATen/native/Resize.h @@ -23,11 +23,13 @@ namespace at { namespace native { // NOTE: In the future the warning will become an error // Returns a bool saying whether or not the resize actually happened or not TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape); +TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape); // Utility for resize_output // Returns a bool saying resize should happen or not and // raises a warning if resizing for one or more elements TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape); +TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape); TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes); @@ -54,34 +56,11 @@ static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_by } } -inline TensorImpl* resize_impl_cpu_( +TORCH_API TensorImpl* resize_impl_cpu_( TensorImpl* self, IntArrayRef size, at::OptionalIntArrayRef stride, - bool resize_storage = true) { - if (self->sizes() == size && (!stride || self->strides() == stride.value())) { - return self; - } - - const auto itemsize = self->dtype().itemsize(); - const auto storage_offset = self->storage_offset(); - size_t storage_size = 1; - if (stride) { - self->set_sizes_and_strides(size, *stride); - storage_size = at::detail::computeStorageNbytes( - size, *stride, itemsize, storage_offset); - } else { - 
self->set_sizes_contiguous(size);
-    storage_size = at::detail::computeStorageNbytesContiguous(
-        size, itemsize, storage_offset);
-  }
-
-  if (resize_storage) {
-    maybe_resize_storage_cpu(self, storage_size);
-  }
-
-  return self;
-}
+    bool resize_storage = true);
 
 template <typename T>
 T maybe_convert_symint(c10::SymInt) = delete;
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index fa0cab8f1e5..2390dbd582d 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2284,7 +2284,8 @@
   device_guard: False
   tags: inplace_view
   dispatch:
-    CPU, Meta: resize_
+    Meta: resize__symint
+    CPU: resize_
     CUDA: resize_cuda_
     MPS: resize_mps_
     QuantizedCPU: quantized_resize_cpu_
diff --git a/c10/core/MemoryFormat.h b/c10/core/MemoryFormat.h
index b1363033ca4..f4e3e930790 100644
--- a/c10/core/MemoryFormat.h
+++ b/c10/core/MemoryFormat.h
@@ -61,8 +61,9 @@ inline std::ostream& operator<<(
 
 // Note: Hardcoded the channel last stride indices here to get better
 // performance
-inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
-  std::vector<int64_t> strides(sizes.size());
+template <typename T>
+inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) {
+  std::vector<T> strides(sizes.size());
   switch (sizes.size()) {
     case 4:
       strides[1] = 1;
@@ -81,8 +82,13 @@ inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
   }
 }
 
-inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
-  std::vector<int64_t> strides(sizes.size());
+inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
+  return get_channels_last_strides_2d<int64_t>(sizes);
+}
+
+template <typename T>
+std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
+  std::vector<T> strides(sizes.size());
   switch (sizes.size()) {
     case 5:
       strides[1] = 1;
@@ -103,6 +109,10 @@ inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
   }
 }
 
+inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
+  return get_channels_last_strides_3d<int64_t>(sizes);
+}
+
 // NOTE:
 // Below are Helper functions for is_channels_last_strides_xd.
 // 1. Please do not combine these helper functions, each helper function handles
diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp
index bee3fa32ec2..fccf2dd1094 100644
--- a/c10/core/TensorImpl.cpp
+++ b/c10/core/TensorImpl.cpp
@@ -993,6 +993,10 @@ void TensorImpl::set_sizes_and_strides(
       set_storage_offset(storage_offset->as_int_unchecked());
     return;
   }
+  TORCH_CHECK(
+      allow_tensor_metadata_change(),
+      "set_sizes_and_strides ",
+      err_msg_tensor_metadata_change_not_allowed);
 
   has_symbolic_sizes_strides_ = true;
   refresh_sizes_strides_policy();
@@ -1011,6 +1015,81 @@ void TensorImpl::set_sizes_and_strides(
   refresh_contiguous();
 }
 
+void TensorImpl::generic_set_sizes_contiguous(SymIntArrayRef sizes) {
+  auto int_sizes = asIntArrayRefSlowOpt(sizes);
+  if (int_sizes.has_value()) {
+    set_sizes_contiguous(*int_sizes);
+    return;
+  }
+
+  TORCH_CHECK(
+      allow_tensor_metadata_change(),
+      "generic_set_sizes_contiguous ",
+      err_msg_tensor_metadata_change_not_allowed);
+
+  has_symbolic_sizes_strides_ = true;
+  refresh_sizes_strides_policy();
+  if (!extra_meta_) {
+    extra_meta_ = std::make_unique<ExtraMeta>();
+    extra_meta_->storage_offset_ = storage_offset_;
+  }
+
+  clone_symvec(sizes, extra_meta_->sizes_);
+  refresh_numel();
+  empty_tensor_restride_symint(
+      MemoryFormat::Contiguous); // calls refresh_contiguous()
+}
+
+void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
+  TORCH_INTERNAL_ASSERT(has_symbolic_sizes_strides_);
+#ifdef DEBUG
+  TORCH_INTERNAL_ASSERT(
+      compute_numel() == numel_,
+      "If you are seeing this error, that means empty_tensor_restride was "
+      "called before setting correct numel");
+#endif
+  switch (memory_format) {
+    case MemoryFormat::Contiguous: {
+      // dim_ is a virtual call, don't repeat it
+      const auto dim_ = dim();
+      extra_meta_->strides_.resize(dim_);
+      if (dim_ > 0) {
+        const auto last_idx = dim_ - 1;
+        extra_meta_->strides_[last_idx] = c10::SymInt(1);
+        for (auto i = last_idx - 1; i >= 0; --i) {
+          extra_meta_->strides_[i] =
+              extra_meta_->strides_[i + 1] * extra_meta_->sizes_[i + 1].max(1);
+        }
+      }
+      break;
+    }
+    case MemoryFormat::ChannelsLast: {
+      TORCH_CHECK(
+          dim() == 4, "required rank 4 tensor to use channels_last format");
+      set_sizes_and_strides(
+          sym_sizes(), get_channels_last_strides_2d(sym_sizes()));
+      break;
+    }
+    case MemoryFormat::ChannelsLast3d: {
+      TORCH_CHECK(
+          dim() == 5, "required rank 5 tensor to use channels_last_3d format");
+      set_sizes_and_strides(
+          sym_sizes(), get_channels_last_strides_3d(sym_sizes()));
+      break;
+    }
+    case MemoryFormat::Preserve:
+      TORCH_CHECK(false, "unsupported memory format ", memory_format);
+      // Cleaning warning messages, no need to break as TORCH_CHECK(false)
+      // terminates flow.
+      // break;
+    case MemoryFormat::NumOptions:
+      TORCH_INTERNAL_ASSERT(false, "invalid memory format ", memory_format);
+  }
+  // recompute contiguous flag, as currently NHWC/NCHW flags are not mutually
+  // exclusive see #24090
+  refresh_contiguous();
+}
+
 namespace impl {
 
 namespace {
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index a6ba3f16e2a..3a0cce80991 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -692,6 +692,30 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return sym_sizes();
   }
 
+  template <typename T>
+  ArrayRef<T> generic_strides() {
+    return _generic_strides(identity<T>());
+  }
+
+  ArrayRef<int64_t> _generic_strides(identity<int64_t>) {
+    return strides();
+  }
+  ArrayRef<c10::SymInt> _generic_strides(identity<c10::SymInt>) {
+    return sym_strides();
+  }
+
+  template <typename T>
+  T generic_storage_offset() {
+    return _generic_storage_offset(identity<T>());
+  }
+
+  int64_t _generic_storage_offset(identity<int64_t>) {
+    return storage_offset();
+  }
+  c10::SymInt _generic_storage_offset(identity<c10::SymInt>) {
+    return sym_storage_offset();
+  }
+
   /**
    * The number of elements in a tensor.
    *
@@ -1604,6 +1628,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
       c10::SymIntArrayRef sizes,
       c10::SymIntArrayRef strides,
       c10::optional<c10::SymInt> storage_offset = c10::nullopt);
+  // This is renamed to avoid breaking overload BC
+  void generic_set_sizes_contiguous(c10::SymIntArrayRef sizes);
+  void generic_set_sizes_contiguous(c10::IntArrayRef sizes) {
+    set_sizes_contiguous(sizes);
+  }
 
   /**
    * Change the size at some dimension. This DOES NOT update strides;
   *
@@ -2311,6 +2340,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     data_type_ = data_type;
   }
 
+  void empty_tensor_restride_symint(MemoryFormat memory_format);
+
   /**
   * Set the strides of the tensor to match memory_format
   *
   * memory contiguous
   */
  void empty_tensor_restride(MemoryFormat memory_format) {
-    TORCH_CHECK(
-        !has_symbolic_sizes_strides_,
-        "empty_tensor_restride() called on tensor with symbolic shape")
+    if (has_symbolic_sizes_strides_) {
+      empty_tensor_restride_symint(memory_format);
+      return;
+    }
 #ifdef DEBUG
     TORCH_INTERNAL_ASSERT(
         compute_numel() == numel_,
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 8bc6d3235ef..db625fbbac0 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -798,6 +798,18 @@ class TestSymbolicTracing(TestCase):
 
         return traced_f
 
+    def test_resize_from_zero(self):
+        def f(x, y):
+            x.resize_(y.size(0))
+
+        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, x_1, y_1):
+    sym_size = torch.ops.aten.sym_size(y_1, 0);  y_1 = None
+    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size]);  x_1 = sym_size = None
+    return None""")
+
+
     def test_unary(self):
         def f(x):
             assert x.shape[0] < 20
diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp
index e276521aceb..101d8d9b219 100644
--- a/torch/csrc/autograd/VariableTypeManual.cpp
+++ b/torch/csrc/autograd/VariableTypeManual.cpp
@@ -226,7 +226,7 @@ Tensor& copy_(
 const Tensor& resize_(
     c10::DispatchKeySet ks,
     const Tensor& self,
-    IntArrayRef size,
+    SymIntArrayRef size,
     c10::optional<MemoryFormat> optional_memory_format) {
   auto& self_ = unpack(self, "self", 0);
   if (self.requires_grad()) {
@@ -234,7 +234,7 @@ const Tensor& resize_(
   }
   {
     at::AutoDispatchBelowAutograd mode;
-    at::redispatch::resize_(
+    at::redispatch::resize__symint(
        ks & c10::after_autograd_keyset, self_, size, optional_memory_format);
  }
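
Usage sketch (not part of the diff above): with the Meta kernel dispatching to
resize__symint, resize_ can be traced with a symbolic size instead of baking in
a constant. This mirrors the test_resize_from_zero test added in the patch; the
exact module path for make_fx and the printed graph may vary across PyTorch
builds, so treat this as an illustration rather than a guaranteed output.

    # Tracing resize_ with a symbolic size via make_fx
    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x, y):
        # Under tracing_mode="symbolic", y.size(0) is a SymInt, so the traced
        # graph records aten.sym_size followed by aten.resize_.default instead
        # of hard-coding the value 2.
        x.resize_(y.size(0))

    gm = make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2))
    print(gm.code)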