diff --git a/aten/src/ATen/templates/Functions.cpp b/aten/src/ATen/templates/Functions.cpp
index c98bd80a792..baf835eb87b 100644
--- a/aten/src/ATen/templates/Functions.cpp
+++ b/aten/src/ATen/templates/Functions.cpp
@@ -49,6 +49,9 @@ Tensor TensorMaker::make_tensor() {
     } else {
       tensor_impl->set_sizes_contiguous(sizes_);
     }
+    if (storage_offset_) {
+      tensor_impl->set_storage_offset(*storage_offset_);
+    }
   }
 
   return tensor;
@@ -58,14 +61,22 @@ Tensor TensorMaker::make_tensor() {
   std::size_t itemsize = opts_.dtype().itemsize();
 
   if (strides_) {
-    return detail::computeStorageNbytes(sizes_, *strides_, itemsize);
+    auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
+    if (storage_offset_) {
+      storage_size += storage_offset_.value();
+    }
+    return storage_size;
   }
 
   std::size_t size = 1;
   for (std::int64_t s : sizes_) {
     size *= static_cast<std::size_t>(s);
   }
-  return size * itemsize;
+  auto storage_size = size * itemsize;
+  if (storage_offset_) {
+    storage_size += storage_offset_.value();
+  }
+  return storage_size;
 }
 
 inline DataPtr TensorMaker::makeDataPtrFromDeleter() const {
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 506739c3a0d..f40fd53a74f 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -298,6 +298,33 @@ $6 = torch._ops.aten.add_($1, $5)''')
         self.assertEqual(type(torch.full_like(MyTensor(2), 1.)), MyTensor)
         self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)
 
+    def test_make_wrapper_subclass_propagates_metadata(self) -> None:
+        class WrapperTensor(torch.Tensor):
+            elem: torch.Tensor
+
+            __slots__ = ['elem']
+
+            @staticmethod
+            def __new__(cls, elem, *args, **kwargs):
+                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+                    cls, elem.size(),
+                    dtype=elem.dtype, layout=elem.layout,
+                    device=elem.device, requires_grad=elem.requires_grad,
+                    strides=elem.stride(), storage_offset=elem.storage_offset())
+                r.elem = elem
+                return r
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                raise RuntimeError("NYI")
+
+        # non-contiguous strides, non-zero storage offset
+        x = torch.randn(4, 6).t().diagonal(offset=2)
+        y = WrapperTensor(x)
+        self.assertEqual(y.size(), x.size())
+        self.assertEqual(y.stride(), x.stride())
+        self.assertEqual(y.storage_offset(), x.storage_offset())
+
     def test_enable_python_mode_error(self) -> None:
         with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
             with enable_python_mode(torch.Tensor):