SymIntify resize_ and deduplicate memory format logic (#90442)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90442
Approved by: https://github.com/bdhirsh

parent 181d37475d
commit e33f1eeeb7
10 changed files with 314 additions and 107 deletions
@@ -106,6 +106,15 @@ size_t computeStorageNbytes(
 #endif
 }
 
+SymInt computeStorageNbytesContiguous(
+    SymIntArrayRef sizes,
+    SymInt itemsize_bytes,
+    SymInt storage_offset
+  ) {
+  const auto numel = c10::multiply_integers(sizes);
+  return itemsize_bytes * (storage_offset + numel);
+}
+
 // not including mobile-only macros in this function,
 // since mobile shouldn't be using symints.
 SymInt computeStorageNbytes(
@@ -135,8 +144,9 @@ SymInt computeStorageNbytes(
   return itemsize_bytes * (storage_offset + size);
 }
 
-TensorBase empty_generic(
-    IntArrayRef size,
+template <typename T>
+TensorBase _empty_generic(
+    ArrayRef<T> size,
     c10::Allocator* allocator,
     c10::DispatchKeySet ks,
     ScalarType scalar_type,
@@ -144,11 +154,10 @@ TensorBase empty_generic(
   at::detail::check_size_nonnegative(size);
   at::detail::raise_warning_for_complex_half(scalar_type);
   caffe2::TypeMeta dtype = scalarTypeToTypeMeta(scalar_type);
-  size_t size_bytes = computeStorageNbytesContiguous(size, dtype.itemsize());
+  auto size_bytes = computeStorageNbytesContiguous(size, dtype.itemsize());
   auto storage_impl = c10::make_intrusive<StorageImpl>(
       c10::StorageImpl::use_byte_size_t(),
       size_bytes,
       allocator->allocate(size_bytes),
       allocator,
       /*resizeable=*/true);
 
@@ -156,7 +165,7 @@ TensorBase empty_generic(
       std::move(storage_impl), ks, dtype);
   // Default TensorImpl has size [0]
   if (size.size() != 1 || size[0] != 0) {
-    tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
+    tensor.unsafeGetTensorImpl()->generic_set_sizes_contiguous(size);
   }
 
   if (memory_format_opt.has_value()) {
@@ -169,6 +178,15 @@ TensorBase empty_generic(
   return tensor;
 }
 
+TensorBase empty_generic(
+    IntArrayRef size,
+    c10::Allocator* allocator,
+    c10::DispatchKeySet ks,
+    ScalarType scalar_type,
+    c10::optional<c10::MemoryFormat> memory_format_opt) {
+  return _empty_generic(size, allocator, ks, scalar_type, memory_format_opt);
+}
+
 template <typename T>
 TensorBase _empty_strided_generic(
     T size,
@@ -338,59 +356,10 @@ TensorBase empty_symint_meta(
     c10::optional<bool> pin_memory_opt,
     c10::optional<c10::MemoryFormat> memory_format_opt
 ) {
-  auto device = device_or_default(device_opt);
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::Meta);
-  // NB: because there is no SparseMeta (yet), non-strided layout is
-  // exerciseable
-  TORCH_CHECK_NOT_IMPLEMENTED(
-    layout_or_default(layout_opt) == Layout::Strided,
-    "non-strided meta tensors not supported yet"
-  );
-
-  auto scalar_type = dtype_or_default(dtype_opt);
   auto *allocator = GetAllocator(kMeta);
-  constexpr c10::DispatchKeySet meta_dks(c10::DispatchKey::Meta);
-  at::detail::check_size_nonnegative(size);
-  at::detail::raise_warning_for_complex_half(scalar_type);
-  caffe2::TypeMeta dtype = scalarTypeToTypeMeta(scalar_type);
-  SymInt size_bytes = dtype.itemsize();
-  for (auto s : size) {
-    size_bytes = size_bytes * s;
-  }
-  auto storage_impl = c10::make_intrusive<StorageImpl>(
-      c10::StorageImpl::use_byte_size_t(),
-      size_bytes,
-      allocator,
-      /*resizeable=*/true);
-
-  auto tensor = detail::make_tensor_base<TensorImpl>(
-      std::move(storage_impl), meta_dks, dtype);
-
-  int64_t dim = size.size();
-  std::vector<SymInt> strides;
-  strides.resize(dim);
-
-  // TODO: Move this into TensorImpl
-  auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
-  switch (memory_format) {
-    case MemoryFormat::Contiguous: {
-      if (dim > 0) {
-        const auto last_idx = dim - 1;
-        strides.at(last_idx) = 1;
-        for (auto i = last_idx - 1; i >= 0; --i) {
-          // TODO: max with 1
-          strides.at(i) = strides.at(i+1) * size.at(i+1);
-        }
-      }
-      break;
-    }
-    default:
-      TORCH_CHECK(0, "other memory format not implemented yet");
-  }
-
-  tensor.unsafeGetTensorImpl()->set_sizes_and_strides(size, strides);
-
-  return tensor;
+  constexpr c10::DispatchKeySet ks(c10::DispatchKey::Meta);
+  auto scalar_type = dtype_or_default(dtype_opt);
+  return _empty_generic(size, allocator, ks, scalar_type, memory_format_opt);
 }
 
 TensorBase empty_meta(
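A note on the pattern above, for readers skimming the diff: the int64_t and SymInt code paths are collapsed into one `_empty_generic` template, and the public `empty_generic` / `empty_symint_meta` entry points become thin forwarders into it. The following is a minimal, self-contained sketch of that pattern; `SymIntLike`, `storage_nbytes_contiguous`, and the entry-point names are illustrative stand-ins, not the actual ATen API.

#include <cstdint>
#include <vector>

// Hypothetical stand-in for c10::SymInt: an integer-like value that supports
// the arithmetic the shared implementation needs.
struct SymIntLike {
  int64_t hint;
};
inline SymIntLike operator*(SymIntLike a, SymIntLike b) { return {a.hint * b.hint}; }
inline SymIntLike operator+(SymIntLike a, SymIntLike b) { return {a.hint + b.hint}; }

// One templated implementation shared by both element types, mirroring how
// _empty_generic / computeStorageNbytesContiguous are written once here.
template <typename T>
T storage_nbytes_contiguous(const std::vector<T>& sizes, T itemsize, T offset) {
  T numel = T{1};
  for (const T& s : sizes) {
    numel = numel * s;
  }
  return itemsize * (offset + numel);
}

// Thin concrete entry points, analogous to empty_generic forwarding to
// _empty_generic and empty_symint_meta reusing the same body.
inline int64_t nbytes_concrete(const std::vector<int64_t>& sizes, int64_t itemsize) {
  return storage_nbytes_contiguous<int64_t>(sizes, itemsize, 0);
}
inline SymIntLike nbytes_symbolic(const std::vector<SymIntLike>& sizes, SymIntLike itemsize) {
  return storage_nbytes_contiguous<SymIntLike>(sizes, itemsize, SymIntLike{0});
}

Under this split, callers with concrete sizes never pay for symbolic bookkeeping, while the meta path reuses the identical logic with symbolic values.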
@@ -20,6 +20,10 @@ TORCH_API size_t computeStorageNbytesContiguous(
     IntArrayRef sizes,
     size_t itemsize,
     size_t storage_offset = 0);
+TORCH_API SymInt computeStorageNbytesContiguous(
+    SymIntArrayRef sizes,
+    SymInt itemsize,
+    SymInt storage_offset = 0);
 TORCH_API size_t computeStorageNbytes(
     IntArrayRef sizes,
     IntArrayRef strides,
@@ -10,20 +10,22 @@
 #else
 #include <ATen/ops/resize_as_native.h>
 #include <ATen/ops/resize_native.h>
+#include <ATen/ops/resize.h>
 #endif
 
 namespace at { namespace native {
 
 // Returns true if resize is necessary
-bool resize_output_check(const Tensor& output, IntArrayRef shape) {
+template <typename T>
+bool _resize_output_check(const Tensor& output, ArrayRef<T> shape) {
   // Tests for resizing of tensors with one or more elements
-  if (output.sizes().equals(shape)) {
+  if (at::symint::sizes<T>(output).equals(shape)) {
     return false;
   }
-  if (output.numel() != 0) {
+  if (at::symint::numel<T>(output) != 0) {
     TORCH_WARN(
       "An output with one or more elements was resized since it had ",
-      "shape ", output.sizes(), ", which does not match the required ",
+      "shape ", at::symint::sizes<T>(output), ", which does not match the required ",
       "output shape ", shape, ". ",
       "This behavior is deprecated, and in a future PyTorch release outputs ",
      "will not be resized unless they have zero elements. You can explicitly ",
@@ -33,8 +35,25 @@ bool resize_output_check(const Tensor& output, IntArrayRef shape) {
   return true;
 }
 
-bool resize_output(const Tensor& output, IntArrayRef shape) {
-  if (resize_output_check(output, shape)) {
+bool resize_output_check(const Tensor& output, IntArrayRef shape) {
+  return _resize_output_check(output, shape);
+}
+
+bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape) {
+  return _resize_output_check(output, shape);
+}
+
+void native_resize_(const Tensor& output, IntArrayRef shape) {
+  native::resize_(output, shape);
+}
+
+void native_resize_(const Tensor& output, SymIntArrayRef shape) {
+  native::resize__symint(output, shape);
+}
+
+template <typename T>
+bool _resize_output(const Tensor& output, ArrayRef<T> shape) {
+  if (_resize_output_check<T>(output, shape)) {
     // avoid a redispatch for cpu and cuda.
     // TODO: when resize_cuda_ is re-written to be unified with resize_,
     // we can provide the same benefit for cuda.
@@ -42,9 +61,9 @@ bool resize_output(const Tensor& output, IntArrayRef shape) {
     // TODO(#61485): functorch wrapped tensors should not go through the
     // fast path. This is a hack, longer term solutions are in the issue
     if (output.is_cpu() && !isTensorSubclassLike(output)) {
-      at::native::resize_(output, shape);
+      native_resize_(output, shape);
     } else {
-      output.resize_(shape);
+      at::symint::resize_<T>(output, shape);
     }
     return true;
   } else {
@@ -52,6 +71,14 @@ bool resize_output(const Tensor& output, IntArrayRef shape) {
   }
 }
 
+bool resize_output(const Tensor& output, IntArrayRef shape) {
+  return _resize_output(output, shape);
+}
+
+bool resize_output_symint(const Tensor& output, SymIntArrayRef shape) {
+  return _resize_output(output, shape);
+}
+
 const Tensor& _resize_output_(const Tensor& self, IntArrayRef shape, c10::Device device) {
   TORCH_CHECK(self.device() == device, "out Tensor doesn't have the correct device set");
   at::native::resize_output(self, shape);
@@ -126,16 +153,92 @@ const Tensor& resize_as_(
   return result;
 }
 
-const Tensor& resize_(
-    const Tensor& self,
-    IntArrayRef size,
-    c10::optional<MemoryFormat> optional_memory_format) {
-  if (self.has_names()) {
-    return resize_named_tensor_(self, size, optional_memory_format);
 
+void resize_bytes_meta(StorageImpl* storage, c10::SymInt size_bytes) {
+  TORCH_CHECK(storage->resizable(), "Trying to resize storage that is not resizable");
+  storage->set_nbytes(size_bytes);
+}
+
+static void maybe_resize_storage_meta(TensorImpl* self, c10::SymInt new_size_bytes) {
+  // It does not make sense to try to resize a storage
+  // to hold 0 elements, and this can break
+  // if storage_offset is positive but
+  // new_size is 0, so just bail in that case
+  // (same comment is in Resize.h)
+  if (self->sym_numel() == 0) {
+    return;
+  }
+
+  const Storage& storage = self->unsafe_storage();
+  if (!storage) {
+    TORCH_INTERNAL_ASSERT(0, "NYI, this should only be Caffe2");
+  } else if (new_size_bytes > storage.nbytes()) {
+    resize_bytes_meta(storage.unsafeGetStorageImpl(), new_size_bytes);
+  }
+}
+
+static void _maybe_resize_storage(TensorImpl* self, int64_t new_size_bytes) {
+  maybe_resize_storage_cpu(self, new_size_bytes);
+}
+
+static void _maybe_resize_storage(TensorImpl* self, c10::SymInt new_size_bytes) {
+  maybe_resize_storage_meta(self, new_size_bytes);
+}
+
+template <typename T>
+TensorImpl* _resize_impl_(
+    TensorImpl* self,
+    ArrayRef<T> size,
+    at::OptionalArrayRef<T> stride,
+    bool resize_storage) {
+  if (self->generic_sizes<T>() == size && (!stride || self->generic_strides<T>() == stride.value())) {
+    return self;
+  }
+
+  const auto itemsize = self->dtype().itemsize();
+  const auto storage_offset = self->generic_storage_offset<T>();
+  T storage_size = T(1);
+  if (stride) {
+    self->set_sizes_and_strides(size, *stride);
+    storage_size = at::detail::computeStorageNbytes(
+        size, *stride, itemsize, storage_offset);
+  } else {
+    self->generic_set_sizes_contiguous(size);
+    storage_size = at::detail::computeStorageNbytesContiguous(
+        size, itemsize, storage_offset);
+  }
+
+  if (resize_storage) {
+    _maybe_resize_storage(self, storage_size);
+  }
+
+  return self;
+}
+
+TensorImpl* resize_impl_cpu_(
+    TensorImpl* self,
+    IntArrayRef size,
+    at::OptionalIntArrayRef stride,
+    bool resize_storage) {
+  return _resize_impl_(self, size, stride, resize_storage);
+}
+
+TensorImpl* resize_impl_meta_(
+    TensorImpl* self,
+    c10::SymIntArrayRef size,
+    at::OptionalSymIntArrayRef stride,
+    bool resize_storage = true) {
+  return _resize_impl_(self, size, stride, resize_storage);
+}
+
+template <typename T>
+const Tensor& _resize_(
+    const Tensor& self,
+    ArrayRef<T> size,
+    c10::optional<MemoryFormat> optional_memory_format) {
   auto* self_ = self.unsafeGetTensorImpl();
   // NOLINTNEXTLINE(bugprone-argument-comment)
-  resize_impl_cpu_(self_, size, /*strides=*/c10::nullopt);
+  _resize_impl_<T>(self_, size, /*strides=*/c10::nullopt, true);
   if (optional_memory_format.has_value()) {
     auto memory_format =
         optional_memory_format.value();
@@ -148,5 +251,23 @@ const Tensor& resize_(
   return self;
 }
 
+const Tensor& resize_(
+    const Tensor& self,
+    IntArrayRef size,
+    c10::optional<MemoryFormat> optional_memory_format) {
+  if (self.has_names()) {
+    return resize_named_tensor_(self, size, optional_memory_format);
+  }
+  return _resize_(self, size, optional_memory_format);
+}
+
+const Tensor& resize__symint(
+    const Tensor& self,
+    c10::SymIntArrayRef size,
+    c10::optional<MemoryFormat> optional_memory_format) {
+  TORCH_INTERNAL_ASSERT(!self.has_names())
+  return _resize_(self, size, optional_memory_format);
+}
+
 } // namespace native
 } // namespace at
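The same deduplication idea drives this whole file: `_resize_output` and `_resize_impl_` are written once as templates, and the int/SymInt split lives in small overloaded helpers (`native_resize_`, `_maybe_resize_storage`) that overload resolution selects. A rough sketch of that shape, using toy types instead of Tensor and the real storage helpers (all names below are hypothetical):

#include <cstdint>
#include <iostream>
#include <string>

// Toy stand-ins: a "concrete" byte count and a "symbolic" one.
using ConcreteBytes = int64_t;
struct SymbolicBytes { std::string expr; };

// Overloaded backends, playing the role of maybe_resize_storage_cpu vs.
// maybe_resize_storage_meta in the diff above.
void grow_storage(ConcreteBytes nbytes) {
  std::cout << "cpu path: reallocate to " << nbytes << " bytes\n";
}
void grow_storage(const SymbolicBytes& nbytes) {
  std::cout << "meta path: record symbolic size " << nbytes.expr << "\n";
}

// The template never names a backend; the overload set does the dispatch.
template <typename Bytes>
void resize_impl(const Bytes& new_nbytes, bool resize_storage) {
  if (resize_storage) {
    grow_storage(new_nbytes);
  }
}

int main() {
  resize_impl(ConcreteBytes{128}, true);       // picks the cpu-style overload
  resize_impl(SymbolicBytes{"s0 * 4"}, true);  // picks the meta-style overload
}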
@@ -23,11 +23,13 @@ namespace at { namespace native {
 // NOTE: In the future the warning will become an error
 // Returns a bool saying whether or not the resize actually happened or not
 TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
+TORCH_API bool resize_output_symint(const Tensor& output, SymIntArrayRef shape);
 
 // Utility for resize_output
 // Returns a bool saying resize should happen or not and
 // raises a warning if resizing for one or more elements
 TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
+TORCH_API bool resize_output_check_symint(const Tensor& output, SymIntArrayRef shape);
 
 TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
 
@@ -54,34 +56,11 @@ static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
   }
 }
 
-inline TensorImpl* resize_impl_cpu_(
+TORCH_API TensorImpl* resize_impl_cpu_(
     TensorImpl* self,
     IntArrayRef size,
     at::OptionalIntArrayRef stride,
-    bool resize_storage = true) {
-  if (self->sizes() == size && (!stride || self->strides() == stride.value())) {
-    return self;
-  }
-
-  const auto itemsize = self->dtype().itemsize();
-  const auto storage_offset = self->storage_offset();
-  size_t storage_size = 1;
-  if (stride) {
-    self->set_sizes_and_strides(size, *stride);
-    storage_size = at::detail::computeStorageNbytes(
-        size, *stride, itemsize, storage_offset);
-  } else {
-    self->set_sizes_contiguous(size);
-    storage_size = at::detail::computeStorageNbytesContiguous(
-        size, itemsize, storage_offset);
-  }
-
-  if (resize_storage) {
-    maybe_resize_storage_cpu(self, storage_size);
-  }
-
-  return self;
-}
+    bool resize_storage = true);
 
 template <typename T>
 T maybe_convert_symint(c10::SymInt) = delete;
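The `maybe_convert_symint` declaration visible at the end of this hunk relies on a deleted primary template, so only explicitly provided specializations compile. A small self-contained sketch of the idiom, with illustrative names rather than the actual ATen helpers:

#include <cstdint>

// Toy symbolic integer with a concrete hint.
struct Sym {
  int64_t hint;
};

// Deleted primary template: by default there is no way to "convert" a Sym.
template <typename T>
T convert_sym(Sym) = delete;

// Explicit specializations opt specific target types back in.
template <>
inline Sym convert_sym<Sym>(Sym s) {
  return s;  // identity: symbolic stays symbolic
}
template <>
inline int64_t convert_sym<int64_t>(Sym s) {
  return s.hint;  // concretize (the real code would guard on this)
}

// convert_sym<double>(Sym{3});  // would not compile: deleted primary template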
@@ -2284,7 +2284,8 @@
   device_guard: False
   tags: inplace_view
   dispatch:
-    CPU, Meta: resize_
+    Meta: resize__symint
+    CPU: resize_
     CUDA: resize_cuda_
     MPS: resize_mps_
     QuantizedCPU: quantized_resize_cpu_
@@ -61,8 +61,9 @@ inline std::ostream& operator<<(
 
 // Note: Hardcoded the channel last stride indices here to get better
 // performance
-inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
-  std::vector<int64_t> strides(sizes.size());
+template <typename T>
+inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) {
+  std::vector<T> strides(sizes.size());
   switch (sizes.size()) {
     case 4:
       strides[1] = 1;
@@ -81,8 +82,13 @@ inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
   }
 }
 
-inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
-  std::vector<int64_t> strides(sizes.size());
+inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
+  return get_channels_last_strides_2d<int64_t>(sizes);
+}
+
+template <typename T>
+std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
+  std::vector<T> strides(sizes.size());
   switch (sizes.size()) {
     case 5:
       strides[1] = 1;
@@ -103,6 +109,10 @@ inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
   }
 }
 
+inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
+  return get_channels_last_strides_3d<int64_t>(sizes);
+}
+
 // NOTE:
 // Below are Helper functions for is_channels_last_strides_xd.
 // 1. Please do not combine these helper functions, each helper function handles
@@ -993,6 +993,10 @@ void TensorImpl::set_sizes_and_strides(
     set_storage_offset(storage_offset->as_int_unchecked());
     return;
   }
+  TORCH_CHECK(
+      allow_tensor_metadata_change(),
+      "set_sizes_and_strides ",
+      err_msg_tensor_metadata_change_not_allowed);
 
   has_symbolic_sizes_strides_ = true;
   refresh_sizes_strides_policy();
@@ -1011,6 +1015,81 @@ void TensorImpl::set_sizes_and_strides(
   refresh_contiguous();
 }
 
+void TensorImpl::generic_set_sizes_contiguous(SymIntArrayRef sizes) {
+  auto int_sizes = asIntArrayRefSlowOpt(sizes);
+  if (int_sizes.has_value()) {
+    set_sizes_contiguous(*int_sizes);
+    return;
+  }
+
+  TORCH_CHECK(
+      allow_tensor_metadata_change(),
+      "generic_set_sizes_contiguous ",
+      err_msg_tensor_metadata_change_not_allowed);
+
+  has_symbolic_sizes_strides_ = true;
+  refresh_sizes_strides_policy();
+  if (!extra_meta_) {
+    extra_meta_ = std::make_unique<ExtraMeta>();
+    extra_meta_->storage_offset_ = storage_offset_;
+  }
+
+  clone_symvec(sizes, extra_meta_->sizes_);
+  refresh_numel();
+  empty_tensor_restride_symint(
+      MemoryFormat::Contiguous); // calls refresh_contiguous()
+}
+
+void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
+  TORCH_INTERNAL_ASSERT(has_symbolic_sizes_strides_);
+#ifdef DEBUG
+  TORCH_INTERNAL_ASSERT(
+      compute_numel() == numel_,
+      "If you are seeing this error, that means empty_tensor_restride was "
+      "called before setting correct numel");
+#endif
+  switch (memory_format) {
+    case MemoryFormat::Contiguous: {
+      // dim_ is a virtual call, don't repeat it
+      const auto dim_ = dim();
+      extra_meta_->strides_.resize(dim_);
+      if (dim_ > 0) {
+        const auto last_idx = dim_ - 1;
+        extra_meta_->strides_[last_idx] = c10::SymInt(1);
+        for (auto i = last_idx - 1; i >= 0; --i) {
+          extra_meta_->strides_[last_idx] =
+              extra_meta_->strides_[i + 1] * extra_meta_->sizes_[i + 1].max(1);
+        }
+      }
+      break;
+    }
+    case MemoryFormat::ChannelsLast: {
+      TORCH_CHECK(
+          dim() == 4, "required rank 4 tensor to use channels_last format");
+      set_sizes_and_strides(
+          sym_sizes(), get_channels_last_strides_2d(sym_sizes()));
+      break;
+    }
+    case MemoryFormat::ChannelsLast3d: {
+      TORCH_CHECK(
+          dim() == 5, "required rank 5 tensor to use channels_last_3d format");
+      set_sizes_and_strides(
+          sym_sizes(), get_channels_last_strides_3d(sym_sizes()));
+      break;
+    }
+    case MemoryFormat::Preserve:
+      TORCH_CHECK(false, "unsupported memory format ", memory_format);
+      // Cleaning warning messages, no need to break as TORCH_CHECK(false)
+      // terminates flow.
+      // break;
+    case MemoryFormat::NumOptions:
+      TORCH_INTERNAL_ASSERT(false, "invalid memory format ", memory_format);
+  }
+  // recompute contiguous flag, as currently NHWC/NCHW flags are not mutually
+  // exclusive see #24090
+  refresh_contiguous();
+}
+
 namespace impl {
 
 namespace {
@@ -692,6 +692,30 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return sym_sizes();
   }
 
+  template <typename T>
+  ArrayRef<T> generic_strides() {
+    return _generic_strides(identity<T>());
+  }
+
+  ArrayRef<int64_t> _generic_strides(identity<int64_t>) {
+    return strides();
+  }
+  ArrayRef<c10::SymInt> _generic_strides(identity<c10::SymInt>) {
+    return sym_strides();
+  }
+
+  template <typename T>
+  T generic_storage_offset() {
+    return _generic_storage_offset(identity<T>());
+  }
+
+  int64_t _generic_storage_offset(identity<int64_t>) {
+    return storage_offset();
+  }
+  c10::SymInt _generic_storage_offset(identity<c10::SymInt>) {
+    return sym_storage_offset();
+  }
+
   /**
    * The number of elements in a tensor.
    *
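`generic_strides` and `generic_storage_offset` use tag dispatch: the member template wraps the requested type in an `identity<T>` tag and ordinary overloads pick the concrete or symbolic accessor. A compact sketch of the pattern with simplified stand-in types (not the real TensorImpl interface):

#include <cstdint>
#include <string>

// Analogous to the identity<T> tag used in the diff: an empty type whose only
// job is to carry T into overload resolution.
template <typename T>
struct type_tag {};

struct ImplSketch {
  int64_t concrete_offset = 0;
  std::string symbolic_offset = "s0 + 2";  // pretend symbolic expression

  // Public entry point: the return type follows the selected overload.
  template <typename T>
  auto generic_offset() {
    return offset_impl(type_tag<T>{});
  }

 private:
  int64_t offset_impl(type_tag<int64_t>) { return concrete_offset; }
  std::string offset_impl(type_tag<std::string>) { return symbolic_offset; }
};

// Usage: ImplSketch{}.generic_offset<int64_t>() returns the concrete offset,
// ImplSketch{}.generic_offset<std::string>() the "symbolic" one.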
@@ -1604,6 +1628,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
       c10::SymIntArrayRef sizes,
       c10::SymIntArrayRef strides,
       c10::optional<c10::SymInt> storage_offset = c10::nullopt);
+  // This is renamed to avoid breaking overload BC
+  void generic_set_sizes_contiguous(c10::SymIntArrayRef sizes);
+  void generic_set_sizes_contiguous(c10::IntArrayRef sizes) {
+    set_sizes_contiguous(sizes);
+  }
 
   /**
    * Change the size at some dimension. This DOES NOT update strides;
@@ -2311,6 +2340,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     data_type_ = data_type;
   }
 
+  void empty_tensor_restride_symint(MemoryFormat memory_format);
+
   /**
    * Set the strides of the tensor to match memory_format
    *
@@ -2318,9 +2349,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * memory contiguous
    */
  void empty_tensor_restride(MemoryFormat memory_format) {
-    TORCH_CHECK(
-        !has_symbolic_sizes_strides_,
-        "empty_tensor_restride() called on tensor with symbolic shape")
+    if (has_symbolic_sizes_strides_) {
+      empty_tensor_restride_symint(memory_format);
+      return;
+    }
 #ifdef DEBUG
     TORCH_INTERNAL_ASSERT(
         compute_numel() == numel_,
@@ -798,6 +798,18 @@ class TestSymbolicTracing(TestCase):
         return traced_f
 
 
+    def test_resize_from_zero(self):
+        def f(x, y):
+            x.resize_(y.size(0))
+
+        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
+        self.assertExpectedInline(r, """\
+def forward(self, x_1, y_1):
+    sym_size = torch.ops.aten.sym_size(y_1, 0); y_1 = None
+    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size]); x_1 = sym_size = None
+    return None""")
+
+
     def test_unary(self):
         def f(x):
             assert x.shape[0] < 20
@@ -226,7 +226,7 @@ Tensor& copy_(
 const Tensor& resize_(
     c10::DispatchKeySet ks,
     const Tensor& self,
-    IntArrayRef size,
+    SymIntArrayRef size,
     c10::optional<MemoryFormat> optional_memory_format) {
   auto& self_ = unpack(self, "self", 0);
   if (self.requires_grad()) {
@@ -234,7 +234,7 @@ const Tensor& resize_(
   }
   {
     at::AutoDispatchBelowAutograd mode;
-    at::redispatch::resize_(
+    at::redispatch::resize__symint(
        ks & c10::after_autograd_keyset, self_, size, optional_memory_format);
   }
 