From c6e3adaf543c8e241b34f2607c701ceec6db89f5 Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Mon, 28 Aug 2023 19:43:09 -0700 Subject: [PATCH] add dynamic shapes support for subclasses that override size/stride (#107916) This is mostly a minor fix on top of @soulitzer's PR https://github.com/pytorch/pytorch/pull/107839. (1) `strides` wasn't going through the new `set_tensor_attr_with_capsule` flow (2) The dynamic shapes overload for `_make_wrapper_subclass` currently errors when you try to use custom sizes - I removed the error (3) added a test I need this later because I'm adding a `__torch_dispatch__` `FunctionalTensor` wrapper subclass, that needs to support dynamic shapes, and also plumb metadata calls to its inner tensor later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/107916 Approved by: https://github.com/ezyang, https://github.com/soulitzer ghstack dependencies: #107915 --- test/test_python_dispatch.py | 47 ++++++++++++++++++++++ torch/csrc/PyInterpreter.cpp | 53 +++++++++++-------------- torch/csrc/autograd/python_variable.cpp | 5 +-- 3 files changed, 72 insertions(+), 33 deletions(-) diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index ffd5ad06eea..c2896904ed3 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -2029,6 +2029,53 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass) self.assertEqual(e.size(), (4, 2)) + def test_custom_size_policy_dynamic_shapes(self): + data = torch.randn(6, 2) + + class CustomSizeDynamicShapesTensor(torch.Tensor): + @staticmethod + def __new__(cls, inner): + return torch.Tensor._make_wrapper_subclass( + # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. + # Calling the overload that has kwargs causes us to go down the first overload path, + # which will **always** specialize sizes. 
+                    # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
+                    cls,
+                    inner.size(),
+                    inner.stride(),
+                    None,
+                    None,
+                    inner.dtype,
+                    inner.layout,
+                    inner.device,
+                    False,
+                    inner.requires_grad,
+                    "sizes",
+                )
+
+            def __init__(self, inner):
+                self.inner = inner
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args, kwargs):
+                if func == torch.ops.aten.sym_size.default:
+                    return args[0].inner.shape
+                if func == torch.ops.aten.sym_stride.default:
+                    return args[0].inner.shape
+                return NotImplemented
+
+        x = torch.ones(2, 2)
+
+        def trace_fn(x):
+            x_wrapper = CustomSizeDynamicShapesTensor(x)
+            return x_wrapper.size(), x_wrapper.stride()
+        fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x)
+        self.assertExpectedInline(fx_g.code.strip(), """\
+def forward(self, x_1):
+    sym_size = torch.ops.aten.sym_size(x_1, 0)
+    sym_size_1 = torch.ops.aten.sym_size(x_1, 1); x_1 = None
+    return ((sym_size, sym_size_1), (sym_size, sym_size_1))""")
+
     def test_data_ptr_respects_numel_slow_path(self):
         data = torch.randn(6, 2)
diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp
index b5c84d5f6cb..82af9ad87d6 100644
--- a/torch/csrc/PyInterpreter.cpp
+++ b/torch/csrc/PyInterpreter.cpp
@@ -544,6 +544,17 @@ c10::Device ConcretePyInterpreterVTable::device(
   return toDevice(out.ptr());
 }
 
+static void set_tensor_attr_with_capsule(
+    const c10::TensorImpl* tensor,
+    py::capsule& capsule,
+    const char* attr_name) {
+  c10::optional<PyObject*> mb_obj =
+      tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
+  TORCH_CHECK(
+      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
+  py::handle(mb_obj.value()).attr(attr_name) = capsule;
+}
+
 c10::IntArrayRef ConcretePyInterpreterVTable::strides(
     const c10::TensorImpl* self) const {
   pybind11::gil_scoped_acquire gil;
@@ -566,38 +577,20 @@ c10::IntArrayRef ConcretePyInterpreterVTable::strides(
       "Cannot call strides on a tensor with symbolic shapes/strides");
   return
self->strides_default();
  }
-
-  py::object values = py::reinterpret_steal<py::object>(out.ptr());
-
-  c10::optional<PyObject*> mb_obj =
-      self->pyobj_slot()->check_pyobj(getPyInterpreter());
   TORCH_CHECK(
-      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
-  PyObject* subclass = *mb_obj;
-  Py_INCREF(subclass);
-  py::object sub = py::reinterpret_steal<py::object>(subclass);
+      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
+      "strides must be a list or a tuple");
-  py::object os = py::module_::import("torch").attr("overrides");
-  py::function get_buffer =
-      py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
-  auto buffer = get_buffer(sub, values, "stride");
-  auto result = THPUtils_unpackLongs(buffer.ptr());
-  // NOLINTNEXTLINE(performance-no-int-to-ptr)
-  int64_t* start = reinterpret_cast<int64_t*>(result[0]);
-  int64_t len = result[1];
-
-  return c10::IntArrayRef(start, len);
-}
-
-static void set_tensor_attr_with_capsule(
-    const c10::TensorImpl* tensor,
-    py::capsule& capsule,
-    const char* attr_name) {
-  c10::optional<PyObject*> mb_obj =
-      tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
-  TORCH_CHECK(
-      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
-  py::handle(mb_obj.value()).attr(attr_name) = capsule;
+  size_t len = py::len(out);
+  int64_t* ptr = new int64_t[len];
+  auto capsule =
+      py::capsule(ptr, [](void* p) { delete[] reinterpret_cast<int64_t*>(p); });
+  int64_t idx = 0;
+  for (auto it = out.begin(); it != out.end(); ++it, ++idx) {
+    ptr[idx] = py::cast<int64_t>(*it);
+  }
+  set_tensor_attr_with_capsule(self, capsule, "_strides_capsule");
+  return c10::IntArrayRef(ptr, len);
 }
 
 c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
index bb91716e3a6..ea3110160ca 100644
--- a/torch/csrc/autograd/python_variable.cpp
+++ b/torch/csrc/autograd/python_variable.cpp
@@ -720,9 +720,8 @@ static PyObject* THPVariable_make_wrapper_subclass(
   const auto sizes_strides_policy = r.stringViewOptional(10);
   if
(sizes_strides_policy.has_value()) {
-    TORCH_CHECK(
-        false,
-        "Setting sizes_strides_policy isn't supported for this overload")
+    tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
+        parseSizesStridesPolicyArgument(*sizes_strides_policy));
   }
 }