SymIntify resize_ and deduplicate memory format logic (#90442)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90442
Approved by: https://github.com/bdhirsh
Edward Z. Yang 2022-12-10 20:23:17 -08:00 committed by PyTorch MergeBot
parent 181d37475d
commit e33f1eeeb7
10 changed files with 314 additions and 107 deletions
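The recurring pattern in this change is worth calling out before the diff: the shared implementation is templated over the index type (int64_t for concrete shapes, c10::SymInt for symbolic shapes produced during tracing), while the pre-existing entry points stay as thin non-template wrappers so existing call sites and dispatcher registrations are untouched. A minimal sketch of that shape, with hypothetical names (do_work/_do_work are not ATen functions):

#include <c10/core/SymIntArrayRef.h>
#include <c10/util/ArrayRef.h>

// Shared body, written only in terms of operations both int64_t and c10::SymInt support.
template <typename T>
void _do_work(c10::ArrayRef<T> sizes) {
  // ... deduplicated logic ...
}

// Thin wrappers keep the original non-template entry points for existing callers.
void do_work(c10::IntArrayRef sizes) { _do_work(sizes); }
void do_work_symint(c10::SymIntArrayRef sizes) { _do_work(sizes); }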

View file

@@ -106,6 +106,15 @@ size_t computeStorageNbytes(
#endif
}
SymInt computeStorageNbytesContiguous(
SymIntArrayRef sizes,
SymInt itemsize_bytes,
SymInt storage_offset
) {
const auto numel = c10::multiply_integers(sizes);
return itemsize_bytes * (storage_offset + numel);
}
// not including mobile-only macros in this function,
// since mobile shouldn't be using symints.
SymInt computeStorageNbytes(
@@ -135,8 +144,9 @@ SymInt computeStorageNbytes(
return itemsize_bytes * (storage_offset + size);
}
TensorBase empty_generic(
IntArrayRef size,
template <typename T>
TensorBase _empty_generic(
ArrayRef<T> size,
c10::Allocator* allocator,
c10::DispatchKeySet ks,
ScalarType scalar_type,
@@ -144,11 +154,10 @@ TensorBase empty_generic(
at::detail::check_size_nonnegative(size);
at::detail::raise_warning_for_complex_half(scalar_type);
caffe2::TypeMeta dtype = scalarTypeToTypeMeta(scalar_type);
size_t size_bytes = computeStorageNbytesContiguous(size, dtype.itemsize());
auto size_bytes = computeStorageNbytesContiguous(size, dtype.itemsize());
auto storage_impl = c10::make_intrusive<StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
size_bytes,
allocator->allocate(size_bytes),
allocator,
/*resizeable=*/true);
@@ -156,7 +165,7 @@ TensorBase empty_generic(
std::move(storage_impl), ks, dtype);
// Default TensorImpl has size [0]
if (size.size() != 1 || size[0] != 0) {
tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
tensor.unsafeGetTensorImpl()->generic_set_sizes_contiguous(size);
}
if (memory_format_opt.has_value()) {
@@ -169,6 +178,15 @@ TensorBase empty_generic(
return tensor;
}
TensorBase empty_generic(
IntArrayRef size,
c10::Allocator* allocator,
c10::DispatchKeySet ks,
ScalarType scalar_type,
c10::optional<c10::MemoryFormat> memory_format_opt) {
return _empty_generic(size, allocator, ks, scalar_type, memory_format_opt);
}
template <typename T>
TensorBase _empty_strided_generic(
T size,
@@ -338,59 +356,10 @@ TensorBase empty_symint_meta(
c10::optional<bool> pin_memory_opt,
c10::optional<c10::MemoryFormat> memory_format_opt
) {
auto device = device_or_default(device_opt);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::Meta);
// NB: because there is no SparseMeta (yet), non-strided layout is
// exercisable
TORCH_CHECK_NOT_IMPLEMENTED(
layout_or_default(layout_opt) == Layout::Strided,
"non-strided meta tensors not supported yet"
);
auto scalar_type = dtype_or_default(dtype_opt);
auto *allocator = GetAllocator(kMeta);
constexpr c10::DispatchKeySet meta_dks(c10::DispatchKey::Meta);
at::detail::check_size_nonnegative(size);
at::detail::raise_warning_for_complex_half(scalar_type);
caffe2::TypeMeta dtype = scalarTypeToTypeMeta(scalar_type);
SymInt size_bytes = dtype.itemsize();
for (auto s : size) {
size_bytes = size_bytes * s;
}
auto storage_impl = c10::make_intrusive<StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
size_bytes,
allocator,
/*resizeable=*/true);
auto tensor = detail::make_tensor_base<TensorImpl>(
std::move(storage_impl), meta_dks, dtype);
int64_t dim = size.size();
std::vector<SymInt> strides;
strides.resize(dim);
// TODO: Move this into TensorImpl
auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
switch (memory_format) {
case MemoryFormat::Contiguous: {
if (dim > 0) {
const auto last_idx = dim - 1;
strides.at(last_idx) = 1;
for (auto i = last_idx - 1; i >= 0; --i) {
// TODO: max with 1
strides.at(i) = strides.at(i+1) * size.at(i+1);
}
}
break;
}
default:
TORCH_CHECK(0, "other memory format not implemented yet");
}
tensor.unsafeGetTensorImpl()->set_sizes_and_strides(size, strides);
return tensor;
constexpr c10::DispatchKeySet ks(c10::DispatchKey::Meta);
auto scalar_type = dtype_or_default(dtype_opt);
return _empty_generic(size, allocator, ks, scalar_type, memory_format_opt);
}
TensorBase empty_meta(

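A quick worked example of computeStorageNbytesContiguous above (hypothetical caller, not part of the diff); the SymInt overload evaluates the same expression, just symbolically:

#include <ATen/EmptyTensor.h>

// sizes {2, 3, 4}, 4-byte dtype, default storage_offset of 0:
size_t nbytes = at::detail::computeStorageNbytesContiguous({2, 3, 4}, /*itemsize=*/4);
// numel = 2 * 3 * 4 = 24, so nbytes = 4 * (0 + 24) = 96 bytes.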
View file

@@ -20,6 +20,10 @@ TORCH_API size_t computeStorageNbytesContiguous(
IntArrayRef sizes,
size_t itemsize,
size_t storage_offset = 0);
TORCH_API SymInt computeStorageNbytesContiguous(
SymIntArrayRef sizes,
SymInt itemsize,
SymInt storage_offset = 0);
TORCH_API size_t computeStorageNbytes(
IntArrayRef sizes,
IntArrayRef strides,

View file

@@ -10,20 +10,22 @@
#else
#include <ATen/ops/resize_as_native.h>
#include <ATen/ops/resize_native.h>
#include <ATen/ops/resize.h>
#endif
namespace at { namespace native {
// Returns true if resize is necessary
bool resize_output_check(const Tensor& output, IntArrayRef shape) {
template <typename T>
bool _resize_output_check(const Tensor& output, ArrayRef<T> shape) {
// Tests for resizing of tensors with one or more elements
if (output.sizes().equals(shape)) {
if (at::symint::sizes<T>(output).equals(shape)) {
return false;
}
if (output.numel() != 0) {
if (at::symint::numel<T>(output) != 0) {
TORCH_WARN(
"An output with one or more elements was resized since it had ",
"shape ", output.sizes(), ", which does not match the required ",
"shape ", at::symint::sizes<T>(output), ", which does not match the required ",
"output shape ", shape, ". ",
"This behavior is deprecated, and in a future PyTorch release outputs ",
"will not be resized unless they have zero elements. You can explicitly ",
@@ -33,8 +35,25 @@ bool resize_output_check(const Tensor& output, IntArrayRef shape) {
return true;
}
bool resize_output(const Tensor& output, IntArrayRef shape) {
if (resize_output_check(output, shape)) {
bool resize_output_check(const Tensor& output, IntArrayRef shape) {
return _resize_output_check(output, shape);
}
bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape) {
return _resize_output_check(output, shape);
}
void native_resize_(const Tensor& output, IntArrayRef shape) {
native::resize_(output, shape);
}
void native_resize_(const Tensor& output, SymIntArrayRef shape) {
native::resize__symint(output, shape);
}
template <typename T>
bool _resize_output(const Tensor& output, ArrayRef<T> shape) {
if (_resize_output_check<T>(output, shape)) {
// avoid a redispatch for cpu and cuda.
// TODO: when resize_cuda_ is re-written to be unified with resize_,
// we can provide the same benefit for cuda.
@@ -42,9 +61,9 @@ bool resize_output(const Tensor& output, IntArrayRef shape) {
// TODO(#61485): functorch wrapped tensors should not go through the
// fast path. This is a hack, longer term solutions are in the issue
if (output.is_cpu() && !isTensorSubclassLike(output)) {
at::native::resize_(output, shape);
native_resize_(output, shape);
} else {
output.resize_(shape);
at::symint::resize_<T>(output, shape);
}
return true;
} else {
@@ -52,6 +71,14 @@ bool resize_output(const Tensor& output, IntArrayRef shape) {
}
}
bool resize_output(const Tensor& output, IntArrayRef shape) {
return _resize_output(output, shape);
}
bool resize_output_symint(const Tensor& output, SymIntArrayRef shape) {
return _resize_output(output, shape);
}
const Tensor& _resize_output_(const Tensor& self, IntArrayRef shape, c10::Device device) {
TORCH_CHECK(self.device() == device, "out Tensor doesn't have the correct device set");
at::native::resize_output(self, shape);
@@ -126,16 +153,92 @@ const Tensor& resize_as_(
return result;
}
const Tensor& resize_(
const Tensor& self,
IntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
if (self.has_names()) {
return resize_named_tensor_(self, size, optional_memory_format);
void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes) {
TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");
storage->set_nbytes(size_bytes);
}
static void maybe_resize_storage_meta(TensorImpl* self, c10::SymInt new_size_bytes) {
// It does not make sense to try to resize a storage
// to hold 0 elements, and this can break
// if storage_offset is positive but
// new_size is 0, so just bail in that case
// (same comment is in Resize.h)
if (self->sym_numel() == 0) {
return;
}
const Storage& storage = self->unsafe_storage();
if (!storage) {
TORCH_INTERNAL_ASSERT(0, "NYI, this should only be Caffe2");
} else if (new_size_bytes > storage.nbytes()) {
resize_bytes_meta(storage.unsafeGetStorageImpl(), new_size_bytes);
}
}
static void _maybe_resize_storage(TensorImpl* self, int64_t new_size_bytes) {
maybe_resize_storage_cpu(self, new_size_bytes);
}
static void _maybe_resize_storage(TensorImpl* self, c10::SymInt new_size_bytes) {
maybe_resize_storage_meta(self, new_size_bytes);
}
template <typename T>
TensorImpl* _resize_impl_(
TensorImpl* self,
ArrayRef<T> size,
at::OptionalArrayRef<T> stride,
bool resize_storage) {
if (self->generic_sizes<T>() == size && (!stride || self->generic_strides<T>() == stride.value())) {
return self;
}
const auto itemsize = self->dtype().itemsize();
const auto storage_offset = self->generic_storage_offset<T>();
T storage_size = T(1);
if (stride) {
self->set_sizes_and_strides(size, *stride);
storage_size = at::detail::computeStorageNbytes(
size, *stride, itemsize, storage_offset);
} else {
self->generic_set_sizes_contiguous(size);
storage_size = at::detail::computeStorageNbytesContiguous(
size, itemsize, storage_offset);
}
if (resize_storage) {
_maybe_resize_storage(self, storage_size);
}
return self;
}
TensorImpl* resize_impl_cpu_(
TensorImpl* self,
IntArrayRef size,
at::OptionalIntArrayRef stride,
bool resize_storage) {
return _resize_impl_(self, size, stride, resize_storage);
}
TensorImpl* resize_impl_meta_(
TensorImpl* self,
c10::SymIntArrayRef size,
at::OptionalSymIntArrayRef stride,
bool resize_storage = true) {
return _resize_impl_(self, size, stride, resize_storage);
}
template <typename T>
const Tensor& _resize_(
const Tensor& self,
ArrayRef<T> size,
c10::optional<MemoryFormat> optional_memory_format) {
auto* self_ = self.unsafeGetTensorImpl();
// NOLINTNEXTLINE(bugprone-argument-comment)
resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt);
_resize_impl_<T>(self_, size, /*strides=*/c10::nullopt, true);
if (optional_memory_format.has_value()) {
auto memory_format =
optional_memory_format.value();
@@ -148,5 +251,23 @@ const Tensor& resize_(
return self;
}
const Tensor& resize_(
const Tensor& self,
IntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
if (self.has_names()) {
return resize_named_tensor_(self, size, optional_memory_format);
}
return _resize_(self, size, optional_memory_format);
}
const Tensor& resize__symint(
const Tensor& self,
c10::SymIntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
TORCH_INTERNAL_ASSERT(!self.has_names());
return _resize_(self, size, optional_memory_format);
}
} // namespace native
} // namespace at
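A hedged usage sketch of the new symbolic entry points defined above. Concrete values stand in for c10::SymInt (during tracing these would come from a ShapeEnv), the include paths are assumed, and real callers normally go through the dispatcher rather than calling at::native directly:

#include <vector>
#include <ATen/ATen.h>
#include <ATen/native/Resize.h>      // resize_output_symint (assumed path)
#include <ATen/ops/resize_native.h>  // resize__symint (assumed path)

void resize_symint_sketch() {
  // Meta tensors carry only metadata, so the symbolic resize never allocates.
  at::Tensor t = at::empty({0}, at::TensorOptions().device(at::kMeta));
  std::vector<c10::SymInt> new_size = {4, 5};
  at::native::resize__symint(t, new_size, c10::nullopt);

  // out= kernels call resize_output_symint, which warns and resizes only when
  // the existing shape does not already match.
  at::native::resize_output_symint(t, new_size);
}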

View file

@@ -23,11 +23,13 @@ namespace at { namespace native {
// NOTE: In the future the warning will become an error
// Returns a bool saying whether or not the resize actually happened
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);
// Utility for resize_output
// Returns a bool saying whether the resize should happen, and
// raises a warning if resizing a tensor with one or more elements
TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
@@ -54,34 +56,11 @@ static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_by
}
}
inline TensorImpl* resize_impl_cpu_(
TORCH_API TensorImpl* resize_impl_cpu_(
TensorImpl* self,
IntArrayRef size,
at::OptionalIntArrayRef stride,
bool resize_storage = true) {
if (self->sizes() == size && (!stride || self->strides() == stride.value())) {
return self;
}
const auto itemsize = self->dtype().itemsize();
const auto storage_offset = self->storage_offset();
size_t storage_size = 1;
if (stride) {
self->set_sizes_and_strides(size, *stride);
storage_size = at::detail::computeStorageNbytes(
size, *stride, itemsize, storage_offset);
} else {
self->set_sizes_contiguous(size);
storage_size = at::detail::computeStorageNbytesContiguous(
size, itemsize, storage_offset);
}
if (resize_storage) {
maybe_resize_storage_cpu(self, storage_size);
}
return self;
}
bool resize_storage = true);
template <typename T>
T maybe_convert_symint(c10::SymInt) = delete;

View file

@@ -2284,7 +2284,8 @@
device_guard: False
tags: inplace_view
dispatch:
CPU, Meta: resize_
Meta: resize__symint
CPU: resize_
CUDA: resize_cuda_
MPS: resize_mps_
QuantizedCPU: quantized_resize_cpu_
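In dispatch terms, the change above means a resize_ call that reaches the Meta backend (the fake-tensor/tracing case) now runs the SymInt-aware resize__symint kernel from Resize.cpp, while CPU, CUDA, MPS and QuantizedCPU keep their existing concrete-shape kernels.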

View file

@@ -61,8 +61,9 @@ inline std::ostream& operator<<(
// Note: Hardcoded the channel last stride indices here to get better
// performance
inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
std::vector<int64_t> strides(sizes.size());
template <typename T>
inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) {
std::vector<T> strides(sizes.size());
switch (sizes.size()) {
case 4:
strides[1] = 1;
@@ -81,8 +82,13 @@ inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
}
}
inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
std::vector<int64_t> strides(sizes.size());
inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
return get_channels_last_strides_2d<int64_t>(sizes);
}
template <typename T>
std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
std::vector<T> strides(sizes.size());
switch (sizes.size()) {
case 5:
strides[1] = 1;
@@ -103,6 +109,10 @@ inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
}
}
inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
return get_channels_last_strides_3d<int64_t>(sizes);
}
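A worked example of the channels-last mapping (values only, not part of the diff): for NCHW-ordered sizes {N, C, H, W} = {2, 3, 4, 5}, get_channels_last_strides_2d makes the channel dimension innermost, giving strides {H*W*C, 1, W*C, C} = {60, 1, 15, 3}; instantiated with T = c10::SymInt, the same products are simply recorded as symbolic expressions.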
// NOTE:
// Below are Helper functions for is_channels_last_strides_xd.
// 1. Please do not combine these helper functions, each helper function handles

View file

@@ -993,6 +993,10 @@ void TensorImpl::set_sizes_and_strides(
set_storage_offset(storage_offset->as_int_unchecked());
return;
}
TORCH_CHECK(
allow_tensor_metadata_change(),
"set_sizes_and_strides ",
err_msg_tensor_metadata_change_not_allowed);
has_symbolic_sizes_strides_ = true;
refresh_sizes_strides_policy();
@@ -1011,6 +1015,81 @@ void TensorImpl::set_sizes_and_strides(
refresh_contiguous();
}
void TensorImpl::generic_set_sizes_contiguous(SymIntArrayRef sizes) {
auto int_sizes = asIntArrayRefSlowOpt(sizes);
if (int_sizes.has_value()) {
set_sizes_contiguous(*int_sizes);
return;
}
TORCH_CHECK(
allow_tensor_metadata_change(),
"generic_set_sizes_contiguous ",
err_msg_tensor_metadata_change_not_allowed);
has_symbolic_sizes_strides_ = true;
refresh_sizes_strides_policy();
if (!extra_meta_) {
extra_meta_ = std::make_unique<ExtraMeta>();
extra_meta_->storage_offset_ = storage_offset_;
}
clone_symvec(sizes, extra_meta_->sizes_);
refresh_numel();
empty_tensor_restride_symint(
MemoryFormat::Contiguous); // calls refresh_contiguous()
}
void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
TORCH_INTERNAL_ASSERT(has_symbolic_sizes_strides_);
#ifdef DEBUG
TORCH_INTERNAL_ASSERT(
compute_numel() == numel_,
"If you are seeing this error, that means empty_tensor_restride was "
"called before setting correct numel");
#endif
switch (memory_format) {
case MemoryFormat::Contiguous: {
// dim_ is a virtual call, don't repeat it
const auto dim_ = dim();
extra_meta_->strides_.resize(dim_);
if (dim_ > 0) {
const auto last_idx = dim_ - 1;
extra_meta_->strides_[last_idx] = c10::SymInt(1);
for (auto i = last_idx - 1; i >= 0; --i) {
extra_meta_->strides_[i] =
extra_meta_->strides_[i + 1] * extra_meta_->sizes_[i + 1].max(1);
}
}
break;
}
case MemoryFormat::ChannelsLast: {
TORCH_CHECK(
dim() == 4, "required rank 4 tensor to use channels_last format");
set_sizes_and_strides(
sym_sizes(), get_channels_last_strides_2d(sym_sizes()));
break;
}
case MemoryFormat::ChannelsLast3d: {
TORCH_CHECK(
dim() == 5, "required rank 5 tensor to use channels_last_3d format");
set_sizes_and_strides(
sym_sizes(), get_channels_last_strides_3d(sym_sizes()));
break;
}
case MemoryFormat::Preserve:
TORCH_CHECK(false, "unsupported memory format ", memory_format);
// Cleaning warning messages, no need to break as TORCH_CHECK(false)
// terminates flow.
// break;
case MemoryFormat::NumOptions:
TORCH_INTERNAL_ASSERT(false, "invalid memory format ", memory_format);
}
// recompute contiguous flag, as currently NHWC/NCHW flags are not mutually
// exclusive see #24090
refresh_contiguous();
}
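A worked example of the contiguous branch above (not part of the diff): strides are filled right to left as strides[i] = strides[i+1] * max(sizes[i+1], 1), so zero-sized dimensions do not collapse the remaining strides. For sizes {2, 0, 3}: strides[2] = 1, strides[1] = 1 * max(3, 1) = 3, strides[0] = 3 * max(0, 1) = 3, giving {3, 3, 1}; with SymInt sizes the identical recurrence is captured as symbolic mul/max expressions.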
namespace impl {
namespace {

View file

@@ -692,6 +692,30 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return sym_sizes();
}
template <typename T>
ArrayRef<T> generic_strides() {
return _generic_strides(identity<T>());
}
ArrayRef<int64_t> _generic_strides(identity<int64_t>) {
return strides();
}
ArrayRef<c10::SymInt> _generic_strides(identity<c10::SymInt>) {
return sym_strides();
}
template <typename T>
T generic_storage_offset() {
return _generic_storage_offset(identity<T>());
}
int64_t _generic_storage_offset(identity<int64_t>) {
return storage_offset();
}
c10::SymInt _generic_storage_offset(identity<c10::SymInt>) {
return sym_storage_offset();
}
/**
* The number of elements in a tensor.
*
@@ -1604,6 +1628,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
c10::SymIntArrayRef sizes,
c10::SymIntArrayRef strides,
c10::optional<c10::SymInt> storage_offset = c10::nullopt);
// This is renamed to avoid breaking overload BC
void generic_set_sizes_contiguous(c10::SymIntArrayRef sizes);
void generic_set_sizes_contiguous(c10::IntArrayRef sizes) {
set_sizes_contiguous(sizes);
}
/**
* Change the size at some dimension. This DOES NOT update strides;
@@ -2311,6 +2340,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
data_type_ = data_type;
}
void empty_tensor_restride_symint(MemoryFormat memory_format);
/**
* Set the strides of the tensor to match memory_format
*
@@ -2318,9 +2349,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* memory contiguous
*/
void empty_tensor_restride(MemoryFormat memory_format) {
TORCH_CHECK(
!has_symbolic_sizes_strides_,
"empty_tensor_restride() called on tensor with symbolic shape")
if (has_symbolic_sizes_strides_) {
empty_tensor_restride_symint(memory_format);
return;
}
#ifdef DEBUG
TORCH_INTERNAL_ASSERT(
compute_numel() == numel_,

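Two notes on the header changes above. The generic_strides/generic_storage_offset helpers use an identity<T> tag-dispatch idiom so that a single function template can select the int64_t or SymInt accessor; a stripped-down, self-contained sketch of the idiom (hypothetical types, not the real TensorImpl):

#include <cstdint>

template <typename T>
struct identity {};   // empty tag type; only its type matters for overload choice

struct Sym {};        // stand-in for c10::SymInt

struct ImplSketch {
  int64_t int_offset() const { return 0; }
  Sym sym_offset() const { return Sym{}; }

  template <typename T>
  auto offset() const { return _offset(identity<T>{}); }
  int64_t _offset(identity<int64_t>) const { return int_offset(); }
  Sym _offset(identity<Sym>) const { return sym_offset(); }
};

// ImplSketch{}.offset<int64_t>() and ImplSketch{}.offset<Sym>() resolve to the
// matching overload at compile time, just like generic_storage_offset<T>().

Separately, generic_set_sizes_contiguous is a new name rather than an extra set_sizes_contiguous overload; per the comment it avoids an overload-BC break, which (my reading, not stated in the diff) would otherwise surface as ambiguity for existing braced-init-list callers such as impl->set_sizes_contiguous({2, 3}), since the list could convert to either IntArrayRef or SymIntArrayRef.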
View file

@@ -798,6 +798,18 @@ class TestSymbolicTracing(TestCase):
return traced_f
    def test_resize_from_zero(self):
        def f(x, y):
            x.resize_(y.size(0))

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size = torch.ops.aten.sym_size(y_1, 0); y_1 = None
    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size]); x_1 = sym_size = None
    return None""")

    def test_unary(self):
        def f(x):
            assert x.shape[0] < 20

View file

@@ -226,7 +226,7 @@ Tensor& copy_(
const Tensor& resize_(
c10::DispatchKeySet ks,
const Tensor& self,
IntArrayRef size,
SymIntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
auto& self_ = unpack(self, "self", 0);
if (self.requires_grad()) {
@@ -234,7 +234,7 @@ const Tensor& resize_(
}
{
at::AutoDispatchBelowAutograd mode;
at::redispatch::resize_(
at::redispatch::resize__symint(
ks & c10::after_autograd_keyset, self_, size, optional_memory_format);
}