add dynamic shapes support for subclasses that override size/stride (#107916)

This is mostly a minor fix on top of @soulitzer's PR https://github.com/pytorch/pytorch/pull/107839.

(1) `strides` wasn't going through the new `set_tensor_attr_with_capsule` flow.
(2) The dynamic shapes overload of `_make_wrapper_subclass` currently errors when you try to use a custom sizes/strides policy; I removed the error.
(3) Added a test.

I need this because I'm adding a `__torch_dispatch__` `FunctionalTensor` wrapper subclass later that needs to support dynamic shapes and plumb metadata calls through to its inner tensor.
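To make the use case concrete, here is a minimal sketch (hypothetical names; this is not the actual `FunctionalTensor` implementation) of a wrapper subclass that answers size/stride metadata queries from the tensor it wraps, which is exactly the pattern that needs dynamic shapes support:

import torch

# Hypothetical sketch, not the actual FunctionalTensor implementation:
# a wrapper subclass whose size/stride queries are answered by the tensor
# it wraps. Relies on the "sizes" dispatch policy fixed up in this PR.
class MetadataPlumbingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            inner.size(),
            inner.stride(),
            None,            # storage_offset
            None,            # memory_format
            inner.dtype,
            inner.layout,
            inner.device,
            False,           # pin_memory
            inner.requires_grad,
            "sizes",         # dispatch size/stride queries to __torch_dispatch__
        )

    def __init__(self, inner):
        self.inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        # Plumb metadata calls through to the inner tensor.
        if func is torch.ops.aten.sym_size.default:
            return args[0].inner.shape
        if func is torch.ops.aten.sym_stride.default:
            return args[0].inner.stride()
        return NotImplemented

Tracing a function that queries such a wrapper with make_fx(..., tracing_mode="symbolic") should emit sym_size/sym_stride calls rather than baking in constants, which is what the new test below verifies.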

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107916
Approved by: https://github.com/ezyang, https://github.com/soulitzer
ghstack dependencies: #107915
Brian Hirsh, 2023-08-28 19:43:09 -07:00 (committed by PyTorch MergeBot)
parent 4f34caf164 · commit c6e3adaf54
3 changed files with 72 additions and 33 deletions


@@ -2029,6 +2029,53 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
             e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
             self.assertEqual(e.size(), (4, 2))
 
+    def test_custom_size_policy_dynamic_shapes(self):
+        data = torch.randn(6, 2)
+
+        class CustomSizeDynamicShapesTensor(torch.Tensor):
+            @staticmethod
+            def __new__(cls, inner):
+                return torch.Tensor._make_wrapper_subclass(
+                    # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
+                    # Calling the overload that has kwargs causes us to go down the first overload path,
+                    # which will **always** specialize sizes.
+                    # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
+                    cls,
+                    inner.size(),
+                    inner.stride(),
+                    None,
+                    None,
+                    inner.dtype,
+                    inner.layout,
+                    inner.device,
+                    False,
+                    inner.requires_grad,
+                    "sizes",
+                )
+
+            def __init__(self, inner):
+                self.inner = inner
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args, kwargs):
+                if func == torch.ops.aten.sym_size.default:
+                    return args[0].inner.shape
+                if func == torch.ops.aten.sym_stride.default:
+                    return args[0].inner.shape
+                return NotImplemented
+
+        x = torch.ones(2, 2)
+
+        def trace_fn(x):
+            x_wrapper = CustomSizeDynamicShapesTensor(x)
+            return x_wrapper.size(), x_wrapper.stride()
+
+        fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x)
+        self.assertExpectedInline(fx_g.code.strip(), """\
+def forward(self, x_1):
+    sym_size = torch.ops.aten.sym_size(x_1, 0)
+    sym_size_1 = torch.ops.aten.sym_size(x_1, 1);  x_1 = None
+    return ((sym_size, sym_size_1), (sym_size, sym_size_1))""")
+
     def test_data_ptr_respects_numel_slow_path(self):
         data = torch.randn(6, 2)
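As a side note, the expected graph in the test is just what symbolic tracing produces when size/stride queries resolve to the input's symbolic dimensions. A standalone sketch with no subclass involved (`f` here is a made-up example function, not from the PR):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # Under tracing_mode="symbolic", size() returns SymInts, which show up
    # in the graph as torch.ops.aten.sym_size calls; stride() traces similarly.
    return x.size(), x.stride()

g = make_fx(f, tracing_mode="symbolic")(torch.ones(2, 2))
print(g.code)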


@@ -544,6 +544,17 @@ c10::Device ConcretePyInterpreterVTable::device(
   return toDevice(out.ptr());
 }
 
+static void set_tensor_attr_with_capsule(
+    const c10::TensorImpl* tensor,
+    py::capsule& capsule,
+    const char* attr_name) {
+  c10::optional<PyObject*> mb_obj =
+      tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
+  TORCH_CHECK(
+      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
+  py::handle(mb_obj.value()).attr(attr_name) = capsule;
+}
+
 c10::IntArrayRef ConcretePyInterpreterVTable::strides(
     const c10::TensorImpl* self) const {
   pybind11::gil_scoped_acquire gil;
@@ -566,38 +577,20 @@ c10::IntArrayRef ConcretePyInterpreterVTable::strides(
         "Cannot call strides on a tensor with symbolic shapes/strides");
     return self->strides_default();
   }
-  py::object values = py::reinterpret_steal<py::object>(out.ptr());
-  c10::optional<PyObject*> mb_obj =
-      self->pyobj_slot()->check_pyobj(getPyInterpreter());
-  TORCH_CHECK(
-      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
-  PyObject* subclass = *mb_obj;
-  Py_INCREF(subclass);
-  py::object sub = py::reinterpret_steal<py::object>(subclass);
-  py::object os = py::module_::import("torch").attr("overrides");
-  py::function get_buffer =
-      py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
-  auto buffer = get_buffer(sub, values, "stride");
-  auto result = THPUtils_unpackLongs(buffer.ptr());
-  // NOLINTNEXTLINE(performance-no-int-to-ptr)
-  int64_t* start = reinterpret_cast<int64_t*>(result[0]);
-  int64_t len = result[1];
-  return c10::IntArrayRef(start, len);
-}
-
-static void set_tensor_attr_with_capsule(
-    const c10::TensorImpl* tensor,
-    py::capsule& capsule,
-    const char* attr_name) {
-  c10::optional<PyObject*> mb_obj =
-      tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
-  TORCH_CHECK(
-      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
-  py::handle(mb_obj.value()).attr(attr_name) = capsule;
+  TORCH_CHECK(
+      py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
+      "strides must be a list or a tuple");
+  size_t len = py::len(out);
+  int64_t* ptr = new int64_t[len];
+  auto capsule =
+      py::capsule(ptr, [](void* p) { delete[] reinterpret_cast<int64_t*>(p); });
+  int64_t idx = 0;
+  for (auto it = out.begin(); it != out.end(); ++it, ++idx) {
+    ptr[idx] = py::cast<int64_t>(*it);
+  }
+  set_tensor_attr_with_capsule(self, capsule, "_strides_capsule");
+  return c10::IntArrayRef(ptr, len);
 }
 
 c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
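One note on why the capsule exists at all: `strides()` returns a `c10::IntArrayRef`, a non-owning view, so the freshly allocated buffer must outlive the call. Storing a `py::capsule` that owns the buffer as an attribute on the Python subclass ties the buffer's lifetime to the tensor object. A rough Python analogue of that ownership pattern (toy names, not a PyTorch API):

import array

class Owner:
    def cache_strides(self, strides):
        buf = array.array("q", strides)  # heap buffer, like `new int64_t[len]`
        self._strides_capsule = buf      # owning reference parked on the object
        return memoryview(buf)           # non-owning view, like c10::IntArrayRef

o = Owner()
view = o.cache_strides((4, 1))
assert list(view) == [4, 1]  # stays valid for as long as `o` is alive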


@@ -720,9 +720,8 @@ static PyObject* THPVariable_make_wrapper_subclass(
     const auto sizes_strides_policy = r.stringViewOptional(10);
     if (sizes_strides_policy.has_value()) {
-      TORCH_CHECK(
-          false,
-          "Setting sizes_strides_policy isn't supported for this overload")
+      tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
+          parseSizesStridesPolicyArgument(*sizes_strides_policy));
     }
   }