Fix _make_wrapper_subclass's storage_offset handling (#68268)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68268

Previously, `_make_wrapper_subclass` ignored the storage offset it was passed. This PR fixes that by updating TensorMaker::computeStorageSize() and TensorMaker::make_tensor() to take storage_offset into account.

Test Plan:
- added test

Reviewed By: albanD, bdhirsh
Differential Revision: D32396330
Pulled By: zou3519
fbshipit-source-id: 2c85bc4066044fe6cb5ab0fc192de6c9069855fd
parent f0e2ad5037
commit b1aa45a8a7
2 changed files with 40 additions and 2 deletions
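For illustration, a minimal sketch of the behavior this commit fixes (the Wrapper class and the values below are my own, not from the commit): a wrapper subclass built from a view with a non-zero storage offset should report that offset rather than 0.

    import torch

    class Wrapper(torch.Tensor):
        @staticmethod
        def __new__(cls, elem):
            # Mirror the view's metadata, including storage_offset, into the
            # wrapper (keyword names as used in the test added by this commit).
            return torch.Tensor._make_wrapper_subclass(
                cls, elem.size(),
                strides=elem.stride(), storage_offset=elem.storage_offset(),
                dtype=elem.dtype, device=elem.device)

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            raise RuntimeError("NYI")

    base = torch.arange(12.).reshape(3, 4)
    view = base[1:]               # a view starting 4 elements into storage
    w = Wrapper(view)
    print(w.storage_offset())     # 4 with this fix (0 before, since the offset was ignored)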
@@ -49,6 +49,9 @@ Tensor TensorMaker::make_tensor() {
     } else {
       tensor_impl->set_sizes_contiguous(sizes_);
     }
+    if (storage_offset_) {
+      tensor_impl->set_storage_offset(*storage_offset_);
+    }
   }

   return tensor;
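For context, `storage_offset` counts elements (not bytes) from the start of the underlying storage; the added block above records it on the wrapper's TensorImpl. A quick illustration with an ordinary view (my example, not from the commit):

    import torch

    base = torch.arange(10.)
    view = base[3:]               # shares storage with base
    print(view.storage_offset())  # 3 -- measured in elements, not bytes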
@@ -58,14 +61,22 @@ Tensor TensorMaker::make_tensor() {
   std::size_t itemsize = opts_.dtype().itemsize();

   if (strides_) {
-    return detail::computeStorageNbytes(sizes_, *strides_, itemsize);
+    auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
+    if (storage_offset_) {
+      storage_size += storage_offset_.value();
+    }
+    return storage_size;
   }

   std::size_t size = 1;
   for (std::int64_t s : sizes_) {
     size *= static_cast<std::size_t>(s);
   }
-  return size * itemsize;
+  auto storage_size = size * itemsize;
+  if (storage_offset_) {
+    storage_size += storage_offset_.value();
+  }
+  return storage_size;
 }

 inline DataPtr TensorMaker::makeDataPtrFromDeleter() const {
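The storage-size arithmetic above follows the usual strided-layout rule: the farthest element a view can reach sits at storage index storage_offset + sum((size_i - 1) * stride_i), so the backing allocation must cover one element past that. A minimal Python sketch of the element count (my illustration, not code from the commit):

    def min_storage_elems(sizes, strides, storage_offset=0):
        # An empty tensor reaches no elements at all.
        if any(s == 0 for s in sizes):
            return 0
        # Highest storage index addressable by the view, plus one.
        last = storage_offset + sum((sz - 1) * st for sz, st in zip(sizes, strides))
        return last + 1

    # A size-(2,) view with stride (7,) and offset 12 needs 12 + 7 + 1 = 20 elements.
    assert min_storage_elems((2,), (7,), 12) == 20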
@@ -298,6 +298,33 @@ $6 = torch._ops.aten.add_($1, $5)''')
         self.assertEqual(type(torch.full_like(MyTensor(2), 1.)), MyTensor)
         self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)

+    def test_make_wrapper_subclass_propagates_metadata(self) -> None:
+        class WrapperTensor(torch.Tensor):
+            elem: torch.Tensor
+
+            __slots__ = ['elem']
+
+            @staticmethod
+            def __new__(cls, elem, *args, **kwargs):
+                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+                    cls, elem.size(),
+                    dtype=elem.dtype, layout=elem.layout,
+                    device=elem.device, requires_grad=elem.requires_grad,
+                    strides=elem.stride(), storage_offset=elem.storage_offset())
+                r.elem = elem
+                return r
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                raise RuntimeError("NYI")
+
+        # non-contiguous strides, non-zero storage offset
+        x = torch.randn(4, 6).t().diagonal(offset=2)
+        y = WrapperTensor(x)
+        self.assertEqual(y.size(), x.size())
+        self.assertEqual(y.stride(), x.stride())
+        self.assertEqual(y.storage_offset(), x.storage_offset())
+
     def test_enable_python_mode_error(self) -> None:
         with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
             with enable_python_mode(torch.Tensor):
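For reference, the test fixture exercises genuinely non-trivial metadata. By standard stride semantics (my derivation, not stated in the commit), torch.randn(4, 6).t() has shape (6, 4) with strides (1, 6), and its offset-2 diagonal is a size-(2,) view with stride (7,) and storage offset 12:

    import torch

    x = torch.randn(4, 6).t().diagonal(offset=2)
    print(x.size(), x.stride(), x.storage_offset())  # torch.Size([2]) (7,) 12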