add sizes to slowpath

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79295

Approved by: https://github.com/ezyang
George Qi · 2022-06-13 18:07:07 +00:00 · committed by PyTorch MergeBot
parent 543919cfc8
commit 05624bcf7b
8 changed files with 134 additions and 14 deletions

c10/core/TensorImpl.cpp

@@ -384,6 +384,9 @@ bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
 }
 
 IntArrayRef TensorImpl::sizes_custom() const {
+  if (is_python_dispatch()) {
+    return load_pyobj_interpreter()->sizes(this);
+  }
   TORCH_CHECK(
       false, "Tensors of type ", tensorimpl_type_name(), " do not have sizes");
 }
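Note: with the is_python_dispatch() branch above, a subclass that opts into the "sizes" dispatch policy gets its size queries routed through __torch_dispatch__. A minimal sketch of that round trip, modeled on the tests later in this commit; the FixedSize name is illustrative, and _make_wrapper_subclass is assumed to accept dispatch_sizes_strides_policy the way the test helper TestPythonDispatch.subclass_helper forwards it:

    import torch

    class FixedSize(torch.Tensor):
        @staticmethod
        def __new__(cls, elem):
            # Assumption: dispatch_sizes_strides_policy="sizes" opts this
            # subclass into the slow path added by this commit.
            return torch.Tensor._make_wrapper_subclass(
                cls, elem.size(), dtype=elem.dtype, device=elem.device,
                dispatch_sizes_strides_policy="sizes")

        @classmethod
        def __torch_dispatch__(cls, func, types, args, kwargs=None):
            if func.overloadpacket == torch.ops.aten.dim:
                return 2
            if func.overloadpacket == torch.ops.aten.size:
                return (5, 3)  # reported instead of the wrapped tensor's sizes
            return NotImplemented

    t = FixedSize(torch.randn(6, 2))
    print(t.size())  # torch.Size([5, 3]), via sizes_custom() -> PyInterpreter::sizes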

c10/core/TensorImpl.h

@@ -661,6 +661,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
         sizes_and_strides_.size());
   }
 
+  inline IntArrayRef sizes_default() const {
+    return c10::IntArrayRef(
+        reinterpret_cast<const int64_t*>(sizes_and_strides_.sizes_data()),
+        sizes_and_strides_.size());
+  }
+
  protected:
   /**
    * Customization points for the functions above. sizes_strides_policy_

@@ -691,11 +697,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     }
     return is_contiguous_;
   }
-  inline IntArrayRef sizes_default() const {
-    return c10::IntArrayRef(
-        reinterpret_cast<const int64_t*>(sizes_and_strides_.sizes_data()),
-        sizes_and_strides_.size());
-  }
   inline c10::SymIntArrayRef sym_sizes_default() const {
     return c10::SymIntArrayRef(
         reinterpret_cast<const c10::SymInt*>(sizes_and_strides_.sizes_data()),

c10/core/impl/PyInterpreter.cpp

@@ -56,6 +56,12 @@ static c10::IntArrayRef noop_strides_fn(
       "attempted to call `strides` on Tensor with nontrivial PyObject after corresponding interpreter died");
 }
 
+static c10::IntArrayRef noop_sizes_fn(const PyInterpreter*, const TensorImpl*) {
+  TORCH_INTERNAL_ASSERT(
+      0,
+      "attempted to call `sizes` on Tensor with nontrivial PyObject after corresponding interpreter died");
+}
+
 void PyInterpreter::disarm() noexcept {
   name_fn_ = &noop_name_fn;
   decref_fn_ = &noop_decref_fn;

@@ -65,6 +71,7 @@ void PyInterpreter::disarm() noexcept {
   device_fn_ = &noop_device_fn;
   dim_fn_ = &noop_dim_fn;
   strides_fn_ = &noop_strides_fn;
+  sizes_fn_ = &noop_sizes_fn;
 }
 
 } // namespace impl

c10/core/impl/PyInterpreter.h

@@ -132,6 +132,7 @@ struct C10_API PyInterpreter {
   using device_sig = c10::Device(const PyInterpreter*, const TensorImpl*);
   using dim_sig = int64_t(const PyInterpreter*, const TensorImpl*);
   using strides_sig = c10::IntArrayRef(const PyInterpreter*, const TensorImpl*);
+  using sizes_sig = c10::IntArrayRef(const PyInterpreter*, const TensorImpl*);
 
   PyInterpreter(
       name_sig* name_fn,

@@ -141,7 +142,8 @@ struct C10_API PyInterpreter {
       is_contiguous_sig* is_contiguous,
       device_sig* device_fn,
       dim_sig* dim_fn,
-      strides_sig* strides)
+      strides_sig* strides,
+      sizes_sig* sizes)
       : name_fn_(name_fn),
         decref_fn_(decref_fn),
         detach_fn_(detach),

@@ -149,7 +151,8 @@ struct C10_API PyInterpreter {
         is_contiguous_fn_(is_contiguous),
         device_fn_(device_fn),
         dim_fn_(dim_fn),
-        strides_fn_(strides) {}
+        strides_fn_(strides),
+        sizes_fn_(sizes) {}
 
   name_sig* name_fn_;
   decref_sig* decref_fn_;

@@ -159,6 +162,7 @@ struct C10_API PyInterpreter {
   device_sig* device_fn_;
   dim_sig* dim_fn_;
   strides_sig* strides_fn_;
+  sizes_sig* sizes_fn_;
 
   // UBSAN suppression fixes: "call to function
   // (anonymous namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*,

@@ -210,6 +214,11 @@ struct C10_API PyInterpreter {
     return (*strides_fn_)(this, self);
   }
 
+  __ubsan_ignore_function__ c10::IntArrayRef sizes(
+      const TensorImpl* self) const {
+    return (*sizes_fn_)(this, self);
+  }
+
   // Disarm this PyInterpreter, making all of its methods noops.
   // Because the function pointers are raw pointers (not atomics),
   // a disarm() invocation that is concurrent with active destructors
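Note: PyInterpreter is a hand-rolled vtable; every hook is a raw function pointer so that disarm() can swap in noops once the owning Python interpreter dies. A rough Python analogue of the pattern, with illustrative names only:

    class Interp:
        # Toy analogue of the function-pointer table above.
        def __init__(self, sizes_fn):
            self.sizes_fn = sizes_fn  # plays the role of sizes_fn_

        def sizes(self, impl):
            return self.sizes_fn(impl)  # like (*sizes_fn_)(this, self)

        def disarm(self):
            # Mirror PyInterpreter::disarm(): replace each slot with a noop
            # that fails loudly, like noop_sizes_fn.
            def noop(_impl):
                raise AssertionError(
                    "attempted to call `sizes` after interpreter died")
            self.sizes_fn = noop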

test/test_python_dispatch.py

@@ -1645,5 +1645,57 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
             e = StridesDefaultReturn(torch.randn(6, 2), use_wrapper_subclass)
             self.assertEqual(e.stride(), (2, 1))
 
+    def test_sizes_slow_path(self):
+        for use_wrapper_subclass in [True, False]:
+            data = torch.randn(6, 2)
+
+            class SizesNotImplemented(torch.Tensor):
+                @staticmethod
+                def __new__(cls, data, wrapper):
+                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")
+
+                @classmethod
+                def __torch_dispatch__(cls, func, types, args, kwargs):
+                    if func.overloadpacket == torch.ops.aten.dim:
+                        return data.dim()
+                    return NotImplemented
+
+            class SizesCustomReturn(torch.Tensor):
+                @staticmethod
+                def __new__(cls, data, wrapper):
+                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")
+
+                @classmethod
+                def __torch_dispatch__(cls, func, types, args, kwargs):
+                    if func.overloadpacket == torch.ops.aten.dim:
+                        return data.dim()
+                    if func.overloadpacket == torch.ops.aten.size:
+                        return (5, 3)
+                    return NotImplemented
+
+            class SizesDefaultReturn(torch.Tensor):
+                @staticmethod
+                def __new__(cls, data, wrapper):
+                    return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")
+
+                @classmethod
+                def __torch_dispatch__(cls, func, types, args, kwargs):
+                    if func.overloadpacket == torch.ops.aten.dim:
+                        return data.dim()
+                    if func.overloadpacket == torch.ops.aten.size:
+                        return None
+                    return NotImplemented
+
+            err_msg = "no implementation found for 'torch.ops.aten.size'"
+            e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
+            with self.assertRaisesRegex(TypeError, err_msg):
+                e.size()
+
+            e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
+            self.assertEqual(e.size(), (5, 3))
+
+            e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
+            self.assertEqual(e.size(), (4, 2))
+
 if __name__ == '__main__':
     run_tests()

third_party/ideep (vendored submodule)

@@ -1 +1 @@
-Subproject commit 8a114a51c116b55c4ceb689b98746786bd00c29b
+Subproject commit 02b17c5748c9349dcc586c359af800c684d9b1ab

torch/csrc/autograd/python_variable.cpp

@@ -251,6 +251,9 @@ int64_t concrete_dim_fn(
 c10::IntArrayRef concrete_strides_fn(
     const c10::impl::PyInterpreter*,
     const c10::TensorImpl* self);
+c10::IntArrayRef concrete_sizes_fn(
+    const c10::impl::PyInterpreter*,
+    const c10::TensorImpl* self);
 
 class PyInterpreterHolder {
  public:

@@ -263,7 +266,8 @@ class PyInterpreterHolder {
           &concrete_is_contiguous_fn,
           &concrete_device_fn,
           &concrete_dim_fn,
-          &concrete_strides_fn)) {}
+          &concrete_strides_fn,
+          &concrete_sizes_fn)) {}
   // NB: intentionally leaks the memory
   ~PyInterpreterHolder() {
     impl_->disarm();
@@ -2279,7 +2283,49 @@ c10::IntArrayRef concrete_strides_fn(
   py::object os = py::module_::import("torch").attr("overrides");
   py::function get_buffer =
       py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
-  auto buffer = get_buffer(sub, values);
+  auto buffer = get_buffer(sub, values, "stride");
   auto result = THPUtils_unpackLongs(buffer.ptr());
   int64_t* start = (int64_t*)result[0];
   int64_t len = result[1];
   return c10::IntArrayRef(start, len);
 }
 
+c10::IntArrayRef concrete_sizes_fn(
+    const c10::impl::PyInterpreter*,
+    const c10::TensorImpl* self) {
+  pybind11::gil_scoped_acquire gil;
+  at::impl::MaybeSetTLSOnEntryGuard guard;
+  auto out = torchDispatchFromTensorImpl(
+      self,
+      "size",
+      py::module::import("torch")
+          .attr("ops")
+          .attr("aten")
+          .attr("size")
+          .attr("default")
+          .ptr(),
+      "torch.ops.aten");
+
+  if (out == Py_None) {
+    return self->sizes_default();
+  }
+
+  py::object values = py::reinterpret_steal<py::object>(out.ptr());
+
+  c10::TensorImpl* ptr = const_cast<c10::TensorImpl*>(self);
+  c10::optional<PyObject*> mb_obj = ptr->check_pyobj(getPyInterpreter());
+  TORCH_CHECK(
+      mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
+  PyObject* subclass = *mb_obj;
+  Py_INCREF(subclass);
+  py::object sub = py::reinterpret_steal<py::object>(subclass);
+
+  py::object os = py::module_::import("torch").attr("overrides");
+  py::function get_buffer =
+      py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
+  auto buffer = get_buffer(sub, values, "size");
+  auto result = THPUtils_unpackLongs(buffer.ptr());
+  int64_t* start = (int64_t*)result[0];
+  int64_t len = result[1];
+  return c10::IntArrayRef(start, len);
+}
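Note: the Py_None branch above is what gives None a meaning in the handler: it falls back to sizes_default(), i.e. the real sizes. A hedged end-to-end sketch (the class name and the _make_wrapper_subclass keywords mirror the test helper and are assumptions, not part of this diff):

    import torch

    class DefaultSizes(torch.Tensor):
        @staticmethod
        def __new__(cls, elem):
            return torch.Tensor._make_wrapper_subclass(
                cls, elem.size(), dtype=elem.dtype, device=elem.device,
                dispatch_sizes_strides_policy="sizes")

        @classmethod
        def __torch_dispatch__(cls, func, types, args, kwargs=None):
            if func.overloadpacket == torch.ops.aten.dim:
                return 2
            if func.overloadpacket == torch.ops.aten.size:
                return None  # concrete_sizes_fn then uses sizes_default()
            return NotImplemented

    print(DefaultSizes(torch.randn(4, 2)).size())  # torch.Size([4, 2])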

torch/overrides.py

@@ -1954,10 +1954,12 @@ class enable_reentrant_dispatch():
     def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         del self._raii_guard
 
-def get_buffer(tensor_subclass, data):
+def get_buffer(tensor_subclass, data, prefix):
     import ctypes
-    if not hasattr(tensor_subclass, "_stride_buffer"):
+    assert prefix in {"stride", "size"}
+    buffer_name = f"_{prefix}_buffer"
+    if not hasattr(tensor_subclass, buffer_name):
         SizeType = ctypes.c_longlong * len(data)
-        tensor_subclass._stride_buffer = SizeType(*data)
-    ptr = ctypes.addressof(tensor_subclass._stride_buffer)
+        setattr(tensor_subclass, buffer_name, SizeType(*data))
+    ptr = ctypes.addressof(getattr(tensor_subclass, buffer_name))
     return (ptr, len(data))
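Note: the (pointer, length) contract exists because concrete_sizes_fn must return an IntArrayRef, a non-owning view, so the Python-side values need a stable home; get_buffer pins a ctypes array on the subclass type itself. A standalone sketch of the same trick (pack_longs and Owner are illustrative names):

    import ctypes

    def pack_longs(owner, data, prefix):
        # Copy `data` into a ctypes array cached on `owner`, and hand back
        # a raw address plus length, as get_buffer does for C++.
        assert prefix in {"stride", "size"}
        name = f"_{prefix}_buffer"
        if not hasattr(owner, name):
            setattr(owner, name, (ctypes.c_longlong * len(data))(*data))
        return ctypes.addressof(getattr(owner, name)), len(data)

    class Owner:
        pass

    ptr, n = pack_longs(Owner, (5, 3), "size")
    print((ctypes.c_longlong * n).from_address(ptr)[:])  # [5, 3]

Note that the buffer is created once per class and prefix; later calls reuse the first call's contents.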