From b1aa45a8a78211fcfec0923b5facca2b2829311d Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Thu, 18 Nov 2021 07:06:26 -0800 Subject: [PATCH] Fix `_make_wrapper_subclass`'s storage_offset handling (#68268) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68268 Previously, `_make_wrapper_subclass` ignored the storage offset it was passed. This PR fixes that by updating TensorMaker::computeStorageSize() and TensorMaker::make_tensor() to take into account storage_offset. Test Plan: - added test Reviewed By: albanD, bdhirsh Differential Revision: D32396330 Pulled By: zou3519 fbshipit-source-id: 2c85bc4066044fe6cb5ab0fc192de6c9069855fd --- aten/src/ATen/templates/Functions.cpp | 15 +++++++++++++-- test/test_python_dispatch.py | 27 +++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/templates/Functions.cpp b/aten/src/ATen/templates/Functions.cpp index c98bd80a792..baf835eb87b 100644 --- a/aten/src/ATen/templates/Functions.cpp +++ b/aten/src/ATen/templates/Functions.cpp @@ -49,6 +49,9 @@ Tensor TensorMaker::make_tensor() { } else { tensor_impl->set_sizes_contiguous(sizes_); } + if (storage_offset_) { + tensor_impl->set_storage_offset(*storage_offset_); + } } return tensor; @@ -58,14 +61,22 @@ Tensor TensorMaker::make_tensor() { std::size_t itemsize = opts_.dtype().itemsize(); if (strides_) { - return detail::computeStorageNbytes(sizes_, *strides_, itemsize); + auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize); + if (storage_offset_) { + storage_size += storage_offset_.value(); + } + return storage_size; } std::size_t size = 1; for (std::int64_t s : sizes_) { size *= static_cast<std::size_t>(s); } - return size * itemsize; + auto storage_size = size * itemsize; + if (storage_offset_) { + storage_size += storage_offset_.value(); + } + return storage_size; } inline DataPtr TensorMaker::makeDataPtrFromDeleter() const { diff --git a/test/test_python_dispatch.py 
b/test/test_python_dispatch.py index 506739c3a0d..f40fd53a74f 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -298,6 +298,33 @@ $6 = torch._ops.aten.add_($1, $5)''') self.assertEqual(type(torch.full_like(MyTensor(2), 1.)), MyTensor) self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor) + def test_make_wrapper_subclass_propagates_metadata(self) -> None: + class WrapperTensor(torch.Tensor): + elem: torch.Tensor + + __slots__ = ['elem'] + + @staticmethod + def __new__(cls, elem, *args, **kwargs): + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, elem.size(), + dtype=elem.dtype, layout=elem.layout, + device=elem.device, requires_grad=elem.requires_grad, + strides=elem.stride(), storage_offset=elem.storage_offset()) + r.elem = elem + return r + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + raise RuntimeError("NYI") + + # non-contiguous strides, non-zero storage offset + x = torch.randn(4, 6).t().diagonal(offset=2) + y = WrapperTensor(x) + self.assertEqual(y.size(), x.size()) + self.assertEqual(y.stride(), x.stride()) + self.assertEqual(y.storage_offset(), x.storage_offset()) + def test_enable_python_mode_error(self) -> None: with self.assertRaisesRegex(ValueError, "__torch_dispatch__"): with enable_python_mode(torch.Tensor):