mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
add dynamic shapes support for subclasses that override size/stride (#107916)
This is mostly a minor fix on top of @soulitzer's PR https://github.com/pytorch/pytorch/pull/107839. (1) `strides` wasn't going through the new `set_tensor_attr_with_capsule` flow (2) The dynamic shapes overload for `_make_wrapper_subclass` currently errors when you try to use custom sizes - I removed the error (3) added a test I need this later because I'm adding a `__torch_dispatch__` `FunctionalTensor` wrapper subclass, that needs to support dynamic shapes, and also plumb metadata calls to its inner tensor later. Pull Request resolved: https://github.com/pytorch/pytorch/pull/107916 Approved by: https://github.com/ezyang, https://github.com/soulitzer ghstack dependencies: #107915
This commit is contained in:
parent
4f34caf164
commit
c6e3adaf54
3 changed files with 72 additions and 33 deletions
|
|
@ -2029,6 +2029,53 @@ $0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), p
|
|||
e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
|
||||
self.assertEqual(e.size(), (4, 2))
|
||||
|
||||
def test_custom_size_policy_dynamic_shapes(self):
|
||||
data = torch.randn(6, 2)
|
||||
|
||||
class CustomSizeDynamicShapesTensor(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, inner):
|
||||
return torch.Tensor._make_wrapper_subclass(
|
||||
# TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
|
||||
# Calling the overload that has kwargs causes us to go down the first overload path,
|
||||
# which will **always** specialize sizes.
|
||||
# We should probably eventually fix this so that the first overload can just handle dynamic shapes.
|
||||
cls,
|
||||
inner.size(),
|
||||
inner.stride(),
|
||||
None,
|
||||
None,
|
||||
inner.dtype,
|
||||
inner.layout,
|
||||
inner.device,
|
||||
False,
|
||||
inner.requires_grad,
|
||||
"sizes",
|
||||
)
|
||||
|
||||
def __init__(self, inner):
|
||||
self.inner = inner
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args, kwargs):
|
||||
if func == torch.ops.aten.sym_size.default:
|
||||
return args[0].inner.shape
|
||||
if func == torch.ops.aten.sym_stride.default:
|
||||
return args[0].inner.shape
|
||||
return NotImplemented
|
||||
|
||||
x = torch.ones(2, 2)
|
||||
|
||||
def trace_fn(x):
|
||||
x_wrapper = CustomSizeDynamicShapesTensor(x)
|
||||
return x_wrapper.size(), x_wrapper.stride()
|
||||
fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x)
|
||||
self.assertExpectedInline(fx_g.code.strip(), """\
|
||||
def forward(self, x_1):
|
||||
sym_size = torch.ops.aten.sym_size(x_1, 0)
|
||||
sym_size_1 = torch.ops.aten.sym_size(x_1, 1); x_1 = None
|
||||
return ((sym_size, sym_size_1), (sym_size, sym_size_1))""")
|
||||
|
||||
def test_data_ptr_respects_numel_slow_path(self):
|
||||
data = torch.randn(6, 2)
|
||||
|
||||
|
|
|
|||
|
|
@ -544,6 +544,17 @@ c10::Device ConcretePyInterpreterVTable::device(
|
|||
return toDevice(out.ptr());
|
||||
}
|
||||
|
||||
static void set_tensor_attr_with_capsule(
|
||||
const c10::TensorImpl* tensor,
|
||||
py::capsule& capsule,
|
||||
const char* attr_name) {
|
||||
c10::optional<PyObject*> mb_obj =
|
||||
tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
|
||||
TORCH_CHECK(
|
||||
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
|
||||
py::handle(mb_obj.value()).attr(attr_name) = capsule;
|
||||
}
|
||||
|
||||
c10::IntArrayRef ConcretePyInterpreterVTable::strides(
|
||||
const c10::TensorImpl* self) const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
|
|
@ -566,38 +577,20 @@ c10::IntArrayRef ConcretePyInterpreterVTable::strides(
|
|||
"Cannot call strides on a tensor with symbolic shapes/strides");
|
||||
return self->strides_default();
|
||||
}
|
||||
|
||||
py::object values = py::reinterpret_steal<py::object>(out.ptr());
|
||||
|
||||
c10::optional<PyObject*> mb_obj =
|
||||
self->pyobj_slot()->check_pyobj(getPyInterpreter());
|
||||
TORCH_CHECK(
|
||||
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
|
||||
PyObject* subclass = *mb_obj;
|
||||
Py_INCREF(subclass);
|
||||
py::object sub = py::reinterpret_steal<py::object>(subclass);
|
||||
py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out),
|
||||
"strides must be a list or a tuple");
|
||||
|
||||
py::object os = py::module_::import("torch").attr("overrides");
|
||||
py::function get_buffer =
|
||||
py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
|
||||
auto buffer = get_buffer(sub, values, "stride");
|
||||
auto result = THPUtils_unpackLongs(buffer.ptr());
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
int64_t* start = reinterpret_cast<int64_t*>(result[0]);
|
||||
int64_t len = result[1];
|
||||
|
||||
return c10::IntArrayRef(start, len);
|
||||
}
|
||||
|
||||
static void set_tensor_attr_with_capsule(
|
||||
const c10::TensorImpl* tensor,
|
||||
py::capsule& capsule,
|
||||
const char* attr_name) {
|
||||
c10::optional<PyObject*> mb_obj =
|
||||
tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
|
||||
TORCH_CHECK(
|
||||
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
|
||||
py::handle(mb_obj.value()).attr(attr_name) = capsule;
|
||||
size_t len = py::len(out);
|
||||
int64_t* ptr = new int64_t[len];
|
||||
auto capsule =
|
||||
py::capsule(ptr, [](void* p) { delete[] reinterpret_cast<int64_t*>(p); });
|
||||
int64_t idx = 0;
|
||||
for (auto it = out.begin(); it != out.end(); ++it, ++idx) {
|
||||
ptr[idx] = py::cast<int64_t>(*it);
|
||||
}
|
||||
set_tensor_attr_with_capsule(self, capsule, "_sizes_capsule");
|
||||
return c10::IntArrayRef(ptr, len);
|
||||
}
|
||||
|
||||
c10::IntArrayRef ConcretePyInterpreterVTable::sizes(
|
||||
|
|
|
|||
|
|
@ -720,9 +720,8 @@ static PyObject* THPVariable_make_wrapper_subclass(
|
|||
|
||||
const auto sizes_strides_policy = r.stringViewOptional(10);
|
||||
if (sizes_strides_policy.has_value()) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Setting sizes_strides_policy isn't supported for this overload")
|
||||
tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
|
||||
parseSizesStridesPolicyArgument(*sizes_strides_policy));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue