Removed overhead from reshape() call if tensor doesn't need to be changed (#61466)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61466

## Goal

Per #55126, the performance of `reshape` is worse than that of `alias` in cases where they perform the same operation (i.e. where `reshape` returns a view), because `reshape` delegates to `view` and duplicates some of its operations (specifically `infer_size_dv` and `computeStride`). The goal of this pull request is to reduce or remove the additional overhead that `reshape` has.

### Proposed Implementation

Instead of using `view`, we implement a private/internal operator (`_reshape_alias`) that `reshape` dispatches to, which skips the redundant checks. This is functionally equivalent to `as_strided`; however, it is a lot simpler because it is specialized to this use case, and, importantly, its `backward` implementation is a lot faster. Note that we have to dispatch (`reshape` is a composite operator) because `reshape` can return either a view or a copy of the tensor depending on the parameters, which complicates implementing a derivative/backward for `reshape` directly.

### Why not `as_strided`?

Using `as_strided` directly slows down autograd. With a custom operator equivalent to `_reshape_alias` (i.e. one with a simpler backward function), `reshape` matches the performance of `view`; if we delegate to `as_strided` instead, `reshape` is about 56% slower (both relative to `view` and to the custom operator). This is also why we add an internal operator named `_reshape_alias` rather than exposing a new public operator: it should only be used in the `reshape` case, and it is effectively a more limited version of `view`, `alias`, and `as_strided`.

## Benchmarks

In a micro-benchmark for `backward` running:

```cpp
// Setup
at::Tensor x = torch::empty({2, 2}, torch::requires_grad(true));

// Benchmark loop
// `reshape(-1)` replaced with a call to `view(-1)` for the view baseline
x.pow(4).reshape(-1).mean().backward();
```

I also benchmarked simple operations without gradients using:

```cpp
// Setup (no `requires_grad`, matching the measurement dumps below)
at::Tensor x = torch::empty({2, 2});

// Benchmark loop
x.reshape(-1); // replaced with a call to `view(-1)` for the view baseline
```

Baselined to `view`:

* Original `reshape`: `+3.3%` (without gradients `+20.8%`)
* Using `as_strided`: `+55.1%` (without gradients `+1.0%`)
* Using custom `_reshape_alias`: `-1.0%` (without gradients `+6.2%`)

In absolute terms (note the percentages above were generated by comparing within each run rather than against a single baseline):

* Original `view`: `53.66 us` (without gradients `582.78 ns`)
* Original `reshape`: `55.46 us` (without gradients `704.24 ns`)
* Using `as_strided`: `83.24 us` (without gradients `576.49 ns`)
* Using custom `_reshape_alias`: `53.13 us` (without gradients `536.01 ns`)

Note that these benchmarks perform a backward pass as well. When compared without any gradient computation, the performance differences are more pronounced, since the reshape itself takes up more of the total time.
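(Illustrative, not part of this PR: the same comparison can be sketched from Python with `torch.utils.benchmark`, the harness behind the measurement dumps below; exact numbers will differ from the C++ runs above.)

```python
# Minimal sketch, assuming a recent PyTorch where torch.utils.benchmark is
# available; Timer implicitly imports torch for the stmt/setup strings.
from torch.utils.benchmark import Timer

for stmt in ("x.pow(4).view(-1).mean().backward()",
             "x.pow(4).reshape(-1).mean().backward()"):
    timer = Timer(stmt=stmt, setup="x = torch.empty(2, 2, requires_grad=True)")
    print(timer.blocked_autorange())  # reports median/IQR like the dumps below
```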
### Original performance

<details>
<summary>Benchmark results</summary>

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f0e4d393160>
x.pow(4).view(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 53.66 us
  IQR: 2.70 us (52.54 to 55.24)
  884 measurements, 100 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f0e2ebd4fa0>
x.pow(4).reshape(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 55.46 us
  IQR: 2.61 us (54.39 to 57.01)
  889 measurements, 100 runs per measurement, 1 thread]
2276116 2286256
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f0e5b2e3e20>
  2640  ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&)
  1920  ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
  1520  ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
  1040  ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long>&&)
   980  ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
   720  ???:__tls_get_addr
   520  ???:at::shouldRunRecordFunction(bool*)
   520  ???:__memcpy_avx_unaligned_erms
   200  ???:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10:: ... g>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
   100  ???:c10::TensorImpl::strides() const
   100  ???:c10::TensorImpl::sizes() const
   100  ???:at::(anonymous namespace)::manager()
    77  /tmp/benchmark_utils_jit_build__1626465284__8a34e7ff-cd37-4a82-be28-7f19e081e771/timer_cpp_7815557938202456331/timer_src.cpp:main
    40  ???:c10::TensorImpl::numel() const
   -77  /tmp/benchmark_utils_jit_build__1626465284__8a34e7ff-cd37-4a82-be28-7f19e081e771/timer_cpp_8055217880649990171/timer_src.cpp:main
  -260  ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)

Total: 10140
```

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f850dd66c10>
x.view(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 582.78 ns
  IQR: 33.80 ns (573.80 to 607.61)
  833 measurements, 10000 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f850de31e20>
x.reshape(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 704.24 ns
  IQR: 24.42 ns (697.20 to 721.62)
  679 measurements, 10000 runs per measurement, 1 thread]
56896 67036
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f84e1930bb0>
  2640  ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&)
  1920  ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
  1520  ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
  1040  ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long>&&)
   980  ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
   720  ???:__tls_get_addr
   520  ???:at::shouldRunRecordFunction(bool*)
   520  ???:__memcpy_avx_unaligned_erms
   200  ???:c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10:: ... g>)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
   100  ???:c10::TensorImpl::strides() const
   100  ???:c10::TensorImpl::sizes() const
   100  ???:at::(anonymous namespace)::manager()
    76  /tmp/benchmark_utils_jit_build__1626466038__15fbbac0-2072-4459-8f8e-08121a905b99/timer_cpp_547407365342278353/timer_src.cpp:main
    40  ???:c10::TensorImpl::numel() const
   -76  /tmp/benchmark_utils_jit_build__1626466038__15fbbac0-2072-4459-8f8e-08121a905b99/timer_cpp_3457873755756181226/timer_src.cpp:main
  -260  ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)

Total: 10140
```

</details>

### Using `as_strided`

<details>
<summary>Benchmark results</summary>

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f8b13bb5b50>
x.pow(4).view(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 53.37 us
  IQR: 3.15 us (51.73 to 54.88)
  936 measurements, 100 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f8af55f8490>
x.pow(4).reshape(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 83.24 us
  IQR: 4.05 us (81.20 to 85.25)
  609 measurements, 100 runs per measurement, 1 thread]
2267916 2525061
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f8af55f8e50>
 31930  ???:_int_free
 15940  ???:malloc
 11595  ???:_int_malloc
 10100  ???:torch::autograd::generated::details::as_strided_backward(at::Tensor, at::TensorGeometry, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
  9360  ???:__tls_get_addr
  8280  ???:free
  8100  ???:torch::autograd::VariableType::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
  4520  ???:c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::reset_()
  4080  ???:operator new(unsigned long)
   ...
  -780  ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
  -920  ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
 -1220  ???:torch::autograd::generated::ViewBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)
 -1520  ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
 -1580  ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
 -1680  ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
 -2560  ???:at::detail::computeStride(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::SmallVector<long, 5u> const&)
 -2640  ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
 -4860  ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)

Total: 257145
```

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f93176a0160>
x.view(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 570.55 ns
  IQR: 32.69 ns (552.87 to 585.56)
  874 measurements, 10000 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f92f8f29490>
x.reshape(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 576.49 ns
  IQR: 37.95 ns (559.51 to 597.46)
  861 measurements, 10000 runs per measurement, 1 thread]
56896 58556
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f932556ca60>
  2140  ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
  1940  ???:torch::autograd::VariableType::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
  1880  ???:torch::ADInplaceOrView::(anonymous namespace)::as_strided(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
  1720  ???:at::_ops::as_strided::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
  1520  ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
  1400  ???:at::native::as_strided_tensorimpl(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
  1260  ???:at::_ops::as_strided::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)'2
  1260  ???:at::_ops::as_strided::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::optional<long>)
   980  ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
   ...
  -620  ???:at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, c10::ArrayRef<long ... ::ArrayRef<long>)> const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) const
  -780  ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2
  -780  ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
  -920  ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
 -1520  ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
 -1580  ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
 -1680  ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
 -1740  ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
 -2640  ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)

Total: 1660
```

</details>

### Using custom function (`_reshape_alias`)

<details>
<summary>Benchmark results</summary>

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f16861d6b50>
x.pow(4).view(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 53.50 us
  IQR: 2.64 us (52.32 to 54.96)
  906 measurements, 100 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f1667b2ed60>
x.pow(4).reshape(-1).mean().backward();
setup: at::Tensor x=torch::empty({2,2}, torch::requires_grad(true));
  Median: 53.13 us
  IQR: 3.40 us (51.72 to 55.13)
  914 measurements, 100 runs per measurement, 1 thread]
2269736 2273236
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f1693f8dc10>
  5060  ???:torch::autograd::VariableType::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
  2000  ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
  1780  ???:torch::ADInplaceOrView::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
  1660  ???:at::_ops::_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
  1600  ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&)
  1520  ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
  1240  ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)'2
  1240  ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
  1220  ???:torch::autograd::generated::AliasToShapeBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)
   ...
  -780  ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2
  -780  ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
  -920  ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
 -1220  ???:torch::autograd::generated::ViewBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)
 -1520  ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
 -1580  ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
 -1680  ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
 -2640  ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)
 -4860  ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)

Total: 3500
```

```
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f5287adfb20>
x.view(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 505.10 ns
  IQR: 20.04 ns (500.41 to 520.45)
  944 measurements, 10000 runs per measurement, 1 thread]
[<torch.utils.benchmark.utils.common.Measurement object at 0x7f526951b430>
x.reshape(-1);
setup: at::Tensor x=torch::empty({2,2});
  Median: 536.01 ns
  IQR: 17.81 ns (531.34 to 549.16)
  916 measurements, 10000 runs per measurement, 1 thread]
56896 60376
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f5295896c10>
  2000  ???:at::native::reshape(at::Tensor const&, c10::ArrayRef<long>)
  1860  ???:torch::autograd::VariableType::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
  1780  ???:torch::ADInplaceOrView::(anonymous namespace)::_reshape_alias(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
  1660  ???:at::_ops::_reshape_alias::call(at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
  1600  ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::ArrayRef<long> >(at::Tensor const&, c10::ArrayRef<long> const&, c10::ArrayRef<long> const&)
  1520  ???:at::_ops::reshape::call(at::Tensor const&, c10::ArrayRef<long>)
  1240  ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)'2
  1240  ???:at::_ops::_reshape_alias::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>, c10::ArrayRef<long>)
   980  ???:void at::infer_size_impl<c10::SmallVector<long, 5u> >(c10::ArrayRef<long>, long, c10::SmallVector<long, 5u>&)
   ...
  -620  ???:at::Tensor c10::Dispatcher::redispatch<at::Tensor, at::Tensor const&, c10::ArrayRef<long ... ::ArrayRef<long>)> const&, c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>) const
  -780  ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)'2
  -780  ???:at::_ops::view::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
  -920  ???:c10::SmallVectorImpl<long>::operator=(c10::SmallVectorImpl<long> const&)
 -1520  ???:at::_ops::view::call(at::Tensor const&, c10::ArrayRef<long>)
 -1580  ???:torch::ADInplaceOrView::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
 -1680  ???:at::Tensor at::native::alias_with_sizes_and_strides<c10::SmallVector<long, 5u> >(at::Tensor const&, c10::SmallVector<long, 5u> const&, c10::SmallVector<long, 5u> const&)
 -1740  ???:torch::autograd::VariableType::(anonymous namespace)::view(c10::DispatchKeySet, at::Tensor const&, c10::ArrayRef<long>)
 -2640  ???:at::native::view(at::Tensor const&, c10::ArrayRef<long>)

Total: 3480
```

</details>

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D29792126

Pulled By: laurencer

fbshipit-source-id: f0519b45b65f868aa3e8651679354558bd761dfd
Parent: a8d99a28d7
Commit: adb73d3dcf
8 changed files with 95 additions and 14 deletions
@@ -1040,21 +1040,43 @@ Tensor reshape(const Tensor& self, IntArrayRef proposed_shape) {
     return at::_mkldnn_reshape(self, shape);
   }
 
-  auto stride =
-      at::detail::computeStride(self.sizes(), self.strides(), shape);
-  // `computeStride` returns the proper strides to use if this
-  // `reshape` can be just a view.
-  //
-  // NB: Even though we have viewable geometry and the target strides here,
-  // we do not just call `as_strided` on `self` because the backward
-  // for `as_strided` is not as efficient as that of `view` (since the
-  // former is meant to handle general cases).
+  // `computeStride` returns the proper strides to use if this
+  // `reshape` can be just a view.
+  auto stride = at::detail::computeStride(self.sizes(), self.strides(), shape);
+
+  // NB: Even though we have viewable geometry and the target strides here,
+  // we do not just call `as_strided` on `self` because the backward
+  // for `as_strided` is not as efficient as that of `view` (since the
+  // former is meant to handle general cases).
+  //
+  // Similarly we don't call `view` because it duplicates some of the work
+  // we've already done, and instead call our internal/private operator
+  // `_reshape_alias` that essentially does the same thing as `view` and
+  // `as_strided` without any of the extra overhead.
   if (stride.has_value()) {
-    return self.view(shape);
+    // Temporary check to revert to the old behavior/view in cases where the
+    // device is not supported (e.g. for XLA the operation is not supported
+    // so we use `view` instead).
+    //
+    // We need to do the checks here instead of in `native_functions.yaml`
+    // to preserve backwards compatibility.
+    if (!self.is_xla()) {
+      return self._reshape_alias(shape, stride.value());
+    } else {
+      return self.view(shape);
+    }
   }
   return at::_unsafe_view(self.clone(at::MemoryFormat::Contiguous), shape);
 }
 
+Tensor _reshape_alias(const Tensor& self, IntArrayRef sizes, IntArrayRef strides) {
+  // This is only used by `reshape` in cases where it would otherwise have dispatched
+  // to `view`. This removes the overhead of calling `view` which duplicates some of
+  // the work that's already been done (`infer_size_dv` and `computeStride`).
+
+  return alias_with_sizes_and_strides(self, sizes, strides);
+}
+
 Tensor reshape_as(const Tensor& self, const Tensor& other) {
   return self.reshape(other.sizes());
 }
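As an aside (illustrative, not part of the diff), the two branches above are easy to observe from Python: when `computeStride` succeeds the result aliases the input's storage, otherwise `reshape` falls back to a copy.

```python
import torch

x = torch.randn(4, 6)                # contiguous: viewable geometry exists
v = x.reshape(24)                     # takes the view/_reshape_alias branch
assert v.data_ptr() == x.data_ptr()   # same storage: an alias, not a copy

t = x.t()                             # transposed: no valid view strides
c = t.reshape(24)                     # falls through to clone + _unsafe_view
assert c.data_ptr() != t.data_ptr()   # fresh storage: a copy
```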
@@ -2152,11 +2174,13 @@ Tensor numpy_T(const Tensor &self) {
   return self.permute(transpose_dims);
 }
 
-Tensor view(const Tensor& self, IntArrayRef size) {
+Tensor view(const Tensor& self,
+            IntArrayRef size) {
+
   at::DimVector inferred_size = at::infer_size_dv(size, self.numel());
   auto stride = at::detail::computeStride(self.sizes(),
-                                          self.strides(),
-                                          inferred_size);
+                                          self.strides(),
+                                          inferred_size);
   TORCH_CHECK(stride.has_value(), "view size is "
     "not compatible with input tensor's size and stride (at least one dimension"
     " spans across two contiguous subspaces). Use .reshape(...) instead.");
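For contrast (illustrative, not part of the diff): `view` has no copy fallback, so the `TORCH_CHECK` above fires exactly where `reshape` would silently copy.

```python
import torch

t = torch.randn(4, 6).t()  # non-contiguous; computeStride has no valid answer
t.reshape(24)              # fine: reshape silently copies instead

try:
    t.view(24)             # no fallback here, so the TORCH_CHECK fires
except RuntimeError as e:
    assert "Use .reshape(" in str(e)
```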
@@ -3443,6 +3443,17 @@
   device_check: NoCheck
   device_guard: False
 
+# NOTE [ _reshape_alias ] is meant to be used in the implementation of reshape.
+# They are not user-facing, hence the leading underscore. Please don't use it
+# anywhere else.
+- func: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA: _reshape_alias
+  # We don't need to support mkldnn since this is handled explicitly by the reshape operator.
+
 - func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor
   device_check: NoCheck
   device_guard: False
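A hedged sketch (not part of the PR) of exercising this schema: although `_reshape_alias` is excluded from the direct Python bindings (see the `_SKIP_PYTHON_BINDINGS` hunk below), registered operators are generally still reachable through the dispatcher on builds that include them.

```python
import torch

x = torch.randn(3, 3)
# size/stride must describe a valid aliasing of x's storage
# (here: flattening a contiguous 3x3 tensor to 9 elements).
y = torch.ops.aten._reshape_alias(x, [9], [1])
assert y.data_ptr() == x.data_ptr()  # an alias, as the name suggests
```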
@@ -1117,3 +1117,28 @@ TEST(TensorTest, StdDimension) {
   ASSERT_EQ(torch::std(x, 0, /*unbiased=*/true).numel(), 3);
   ASSERT_EQ(std::get<0>(torch::std_mean(x, 0, /*unbiased=*/true)).numel(), 3);
 }
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+TEST(TensorTest, ReshapeAlias) {
+  // Tests the behavior of the _reshape_alias private operator so
+  // that it matches the behavior of as_strided and view.
+  auto x = torch::randn({3, 3});
+  ASSERT_TRUE(torch::equal(
+    torch::_reshape_alias(x, {2, 2}, {1, 2}),
+    torch::as_strided(x, {2, 2}, {1, 2})
+  ));
+  ASSERT_TRUE(torch::equal(
+    torch::_reshape_alias(x, {9}, {1}),
+    x.view({-1})
+  ));
+
+  // Test that the backward works fine.
+  auto y = torch::randn({3, 3}, torch::requires_grad(true));
+  auto z = torch::clone(y).detach().requires_grad_(true);
+  (y * y).view({-1}).mean().backward();
+  torch::_reshape_alias((z * z), {9}, {1}).mean().backward();
+  ASSERT_TRUE(torch::equal(
+    y.grad(),
+    z.grad()
+  ));
+}

@@ -1355,6 +1355,22 @@ class TestOldViewOps(TestCase):
         self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1))
         self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1))
 
+    @dtypes(*torch.testing.get_all_dtypes())
+    def test_reshape_view_semantics(self, device, dtype):
+        tensor = make_tensor((15, 4), device, dtype)
+        target = (20, 3)
+
+        # Cases where the tensor can be returned as a view.
+        view_tensor = tensor.reshape(target)
+        self.assertEqual((view_tensor.size()), target)
+        self.assertEqual(tensor.storage().data_ptr(), view_tensor.storage().data_ptr())
+
+        # Cases where the tensor must be copied (transpose makes it non-contiguous forcing
+        # the copy).
+        copy_tensor = tensor.transpose(0, 1).reshape(target)
+        self.assertEqual(copy_tensor.size(), target)
+        self.assertNotEqual(tensor.storage().data_ptr(), copy_tensor.storage().data_ptr())
+
     def test_contiguous(self, device):
         x = torch.randn(1, 16, 5, 5, device=device)
         self.assertTrue(x.is_contiguous())

@@ -1144,6 +1144,9 @@
 # making it impossible (hard) to detect when it is actually a view.
 # - name: reshape(Tensor self, IntArrayRef shape)
 
+- name: _reshape_alias(Tensor(a) self, int[] size, int[] stride) -> Tensor(a)
+  self: grad.reshape(self.sizes())
+
 - name: round(Tensor self) -> Tensor
   self: zeros_like(grad)
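The derivative entry above simply reshapes the incoming gradient back to the input's sizes, which is what keeps this backward as cheap as `view`'s. A sketch of the same rule as a `torch.autograd.Function` (illustrative only; the real backward is generated from the yaml):

```python
import torch

class ReshapeAliasLike(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, shape):
        ctx.input_sizes = x.shape
        return x.reshape(shape)

    @staticmethod
    def backward(ctx, grad):
        # Mirrors `self: grad.reshape(self.sizes())` from the entry above.
        return grad.reshape(ctx.input_sizes), None

x = torch.randn(3, 3, requires_grad=True)
ReshapeAliasLike.apply(x, (9,)).mean().backward()
assert x.grad.shape == x.shape
```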
@@ -58,6 +58,7 @@ VIEW_FUNCTIONS = {
     # discrete anyways.
     # FIXME: clone indices on construction.
    'sparse_coo_tensor_with_dims_and_tensors': 'values',
+    '_reshape_alias': 'self',
 }
 
 for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE:
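Registering `'_reshape_alias': 'self'` above tells the autograd codegen that the output aliases `self`, so it is tracked as a proper view. A quick check (illustrative; `_is_view` and `_base` are private attributes):

```python
import torch

x = torch.randn(3, 3, requires_grad=True)
v = x.reshape(9)      # viewable, so it goes through the aliasing path
assert v._is_view()   # autograd recorded a view relationship
assert v._base is x   # and remembers the base tensor it aliases
```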
@@ -92,6 +92,7 @@ _SKIP_PYTHON_BINDINGS = [
     'data', 'is_leaf', 'output_nr', '_version', 'requires_grad_', 'retains_grad', 'set_',
     '_fw_primal', 'fake_quantize_per_tensor_affine_cachemask',
     'fake_quantize_per_channel_affine_cachemask',
+    '_reshape_alias',
 ]
 
 SKIP_PYTHON_BINDINGS = list(map(lambda pattern: re.compile(rf'^{pattern}$'), _SKIP_PYTHON_BINDINGS))

@@ -101,7 +101,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
     'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical',
-    'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'conj_physical_', '_neg_view'
+    'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'conj_physical_', '_neg_view', '_reshape_alias',
 }
 
 GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
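The whitelist entry above asserts that `_reshape_alias`'s gradient formula is also valid for complex dtypes; a smoke-test sketch (not part of the PR):

```python
import torch

x = torch.randn(3, 3, dtype=torch.complex64, requires_grad=True)
# abs() makes the loss real-valued so backward() can run end to end.
x.reshape(9).sum().abs().backward()
assert x.grad is not None and x.grad.shape == x.shape
```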