Fix _make_wrapper_subclass's storage_offset handling (#68268)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68268

Previously, `_make_wrapper_subclass` ignored the storage offset it was
passed. This PR fixes that by updating TensorMaker::computeStorageSize()
and TensorMaker::make_tensor() to take into account storage_offset.

Test Plan: - added test

Reviewed By: albanD, bdhirsh

Differential Revision: D32396330

Pulled By: zou3519

fbshipit-source-id: 2c85bc4066044fe6cb5ab0fc192de6c9069855fd
This commit is contained in:
Richard Zou 2021-11-18 07:06:26 -08:00 committed by Facebook GitHub Bot
parent f0e2ad5037
commit b1aa45a8a7
2 changed files with 40 additions and 2 deletions

View file

@ -49,6 +49,9 @@ Tensor TensorMaker::make_tensor() {
} else {
tensor_impl->set_sizes_contiguous(sizes_);
}
if (storage_offset_) {
tensor_impl->set_storage_offset(*storage_offset_);
}
}
return tensor;
@ -58,14 +61,22 @@ Tensor TensorMaker::make_tensor() {
std::size_t itemsize = opts_.dtype().itemsize();
if (strides_) {
return detail::computeStorageNbytes(sizes_, *strides_, itemsize);
auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
if (storage_offset_) {
storage_size += storage_offset_.value();
}
return storage_size;
}
std::size_t size = 1;
for (std::int64_t s : sizes_) {
size *= static_cast<std::size_t>(s);
}
return size * itemsize;
auto storage_size = size * itemsize;
if (storage_offset_) {
storage_size += storage_offset_.value();
}
return storage_size;
}
inline DataPtr TensorMaker::makeDataPtrFromDeleter() const {

View file

@ -298,6 +298,33 @@ $6 = torch._ops.aten.add_($1, $5)''')
self.assertEqual(type(torch.full_like(MyTensor(2), 1.)), MyTensor)
self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)
def test_make_wrapper_subclass_propagates_metadata(self) -> None:
class WrapperTensor(torch.Tensor):
elem: torch.Tensor
__slots__ = ['elem']
@staticmethod
def __new__(cls, elem, *args, **kwargs):
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls, elem.size(),
dtype=elem.dtype, layout=elem.layout,
device=elem.device, requires_grad=elem.requires_grad,
strides=elem.stride(), storage_offset=elem.storage_offset())
r.elem = elem
return r
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise RuntimeError("NYI")
# non-contiguous strides, non-zero storage offset
x = torch.randn(4, 6).t().diagonal(offset=2)
y = WrapperTensor(x)
self.assertEqual(y.size(), x.size())
self.assertEqual(y.stride(), x.stride())
self.assertEqual(y.storage_offset(), x.storage_offset())
def test_enable_python_mode_error(self) -> None:
with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
with enable_python_mode(torch.Tensor):