mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
add sizes to slowpath
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79295 Approved by: https://github.com/ezyang
This commit is contained in:
parent
543919cfc8
commit
05624bcf7b
8 changed files with 134 additions and 14 deletions
|
|
@ -384,6 +384,9 @@ bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
|
|||
}
|
||||
|
||||
// Slow-path `sizes()` for TensorImpls with a customized sizes/strides policy.
// A Python-dispatch subclass forwards the query to its owning interpreter;
// any other custom TensorImpl has no well-defined sizes and errors out.
IntArrayRef TensorImpl::sizes_custom() const {
  if (!is_python_dispatch()) {
    TORCH_CHECK(
        false, "Tensors of type ", tensorimpl_type_name(), " do not have sizes");
  }
  return load_pyobj_interpreter()->sizes(this);
}
|
||||
|
|
|
|||
|
|
@ -661,6 +661,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
sizes_and_strides_.size());
|
||||
}
|
||||
|
||||
// Default (non-customized) implementation of sizes(): a non-owning view
// directly into the inline sizes/strides storage.
// NOTE(review): the stored elements are reinterpreted as int64_t; this
// presumes the element representation is bit-compatible with int64_t —
// confirm against the SizesAndStrides storage type.
inline IntArrayRef sizes_default() const {
  const auto* raw =
      reinterpret_cast<const int64_t*>(sizes_and_strides_.sizes_data());
  return c10::IntArrayRef(raw, sizes_and_strides_.size());
}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* Customization points for the functions above. sizes_strides_policy_
|
||||
|
|
@ -691,11 +697,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
|||
}
|
||||
return is_contiguous_;
|
||||
}
|
||||
inline IntArrayRef sizes_default() const {
|
||||
return c10::IntArrayRef(
|
||||
reinterpret_cast<const int64_t*>(sizes_and_strides_.sizes_data()),
|
||||
sizes_and_strides_.size());
|
||||
}
|
||||
inline c10::SymIntArrayRef sym_sizes_default() const {
|
||||
return c10::SymIntArrayRef(
|
||||
reinterpret_cast<const c10::SymInt*>(sizes_and_strides_.sizes_data()),
|
||||
|
|
|
|||
|
|
@ -56,6 +56,12 @@ static c10::IntArrayRef noop_strides_fn(
|
|||
"attempted to call `strides` on Tensor with nontrivial PyObject after corresponding interpreter died");
|
||||
}
|
||||
|
||||
// Replacement `sizes` hook installed by PyInterpreter::disarm(): the owning
// Python interpreter has been torn down, so reaching this is a hard
// internal error rather than a recoverable condition.
static c10::IntArrayRef noop_sizes_fn(const PyInterpreter*, const TensorImpl*) {
  TORCH_INTERNAL_ASSERT(
      false,
      "attempted to call `sizes` on Tensor with nontrivial PyObject after corresponding interpreter died");
}
|
||||
|
||||
void PyInterpreter::disarm() noexcept {
|
||||
name_fn_ = &noop_name_fn;
|
||||
decref_fn_ = &noop_decref_fn;
|
||||
|
|
@ -65,6 +71,7 @@ void PyInterpreter::disarm() noexcept {
|
|||
device_fn_ = &noop_device_fn;
|
||||
dim_fn_ = &noop_dim_fn;
|
||||
strides_fn_ = &noop_strides_fn;
|
||||
sizes_fn_ = &noop_sizes_fn;
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
|
|
|||
|
|
@ -132,6 +132,7 @@ struct C10_API PyInterpreter {
|
|||
using device_sig = c10::Device(const PyInterpreter*, const TensorImpl*);
|
||||
using dim_sig = int64_t(const PyInterpreter*, const TensorImpl*);
|
||||
using strides_sig = c10::IntArrayRef(const PyInterpreter*, const TensorImpl*);
|
||||
using sizes_sig = c10::IntArrayRef(const PyInterpreter*, const TensorImpl*);
|
||||
|
||||
PyInterpreter(
|
||||
name_sig* name_fn,
|
||||
|
|
@ -141,7 +142,8 @@ struct C10_API PyInterpreter {
|
|||
is_contiguous_sig* is_contiguous,
|
||||
device_sig* device_fn,
|
||||
dim_sig* dim_fn,
|
||||
strides_sig* strides)
|
||||
strides_sig* strides,
|
||||
sizes_sig* sizes)
|
||||
: name_fn_(name_fn),
|
||||
decref_fn_(decref_fn),
|
||||
detach_fn_(detach),
|
||||
|
|
@ -149,7 +151,8 @@ struct C10_API PyInterpreter {
|
|||
is_contiguous_fn_(is_contiguous),
|
||||
device_fn_(device_fn),
|
||||
dim_fn_(dim_fn),
|
||||
strides_fn_(strides) {}
|
||||
strides_fn_(strides),
|
||||
sizes_fn_(sizes) {}
|
||||
|
||||
name_sig* name_fn_;
|
||||
decref_sig* decref_fn_;
|
||||
|
|
@ -159,6 +162,7 @@ struct C10_API PyInterpreter {
|
|||
device_sig* device_fn_;
|
||||
dim_sig* dim_fn_;
|
||||
strides_sig* strides_fn_;
|
||||
sizes_sig* sizes_fn_;
|
||||
|
||||
// UBSAN suppression fixes: "call to function
|
||||
// (anonymous namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*,
|
||||
|
|
@ -210,6 +214,11 @@ struct C10_API PyInterpreter {
|
|||
return (*strides_fn_)(this, self);
|
||||
}
|
||||
|
||||
// Dispatch `sizes` through the interpreter's installed hook (the noop hook
// after disarm()). The UBSAN suppression mirrors the sibling wrappers that
// call through these raw function pointers.
__ubsan_ignore_function__ c10::IntArrayRef sizes(
    const TensorImpl* self) const {
  return sizes_fn_(this, self);
}
|
||||
|
||||
// Disarm this PyInterpreter, making all of its methods noops.
|
||||
// Because the function pointers are raw pointers (not atomics),
|
||||
// a disarm() invocation that is concurrent with active destructors
|
||||
|
|
|
|||
|
|
@ -1645,5 +1645,57 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
|
|||
e = StridesDefaultReturn(torch.randn(6, 2), use_wrapper_subclass)
|
||||
self.assertEqual(e.stride(), (2, 1))
|
||||
|
||||
def test_sizes_slow_path(self):
    # Exercise the C++ slow path for `sizes()` on subclasses that opt in via
    # dispatch_sizes_strides_policy="sizes": the size query is routed through
    # __torch_dispatch__ instead of the default TensorImpl storage.
    # Runs both with and without a wrapper subclass (see subclass_helper).
    for use_wrapper_subclass in [True, False]:
        data = torch.randn(6, 2)

        # Subclass that handles `dim` but NOT `size` -> calling .size()
        # should surface a "no implementation found" TypeError.
        class SizesNotImplemented(torch.Tensor):
            @staticmethod
            def __new__(cls, data, wrapper):
                return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func.overloadpacket == torch.ops.aten.dim:
                    return data.dim()
                return NotImplemented

        # Subclass that answers `size` with a custom tuple unrelated to the
        # wrapped data -> .size() must report exactly that tuple.
        class SizesCustomReturn(torch.Tensor):
            @staticmethod
            def __new__(cls, data, wrapper):
                return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func.overloadpacket == torch.ops.aten.dim:
                    return data.dim()
                if func.overloadpacket == torch.ops.aten.size:
                    return (5, 3)
                return NotImplemented

        # Subclass that returns None for `size` -> the C++ side falls back
        # to sizes_default(), i.e. the tensor's real shape.
        class SizesDefaultReturn(torch.Tensor):
            @staticmethod
            def __new__(cls, data, wrapper):
                return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="sizes")

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func.overloadpacket == torch.ops.aten.dim:
                    return data.dim()
                if func.overloadpacket == torch.ops.aten.size:
                    return None
                return NotImplemented

        err_msg = "no implementation found for 'torch.ops.aten.size'"
        e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
        with self.assertRaisesRegex(TypeError, err_msg):
            e.size()

        # Custom value wins over the wrapped tensor's actual shape.
        e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
        self.assertEqual(e.size(), (5, 3))

        # None -> default path reports the true shape of the data.
        e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
        self.assertEqual(e.size(), (4, 2))
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
2
third_party/ideep
vendored
2
third_party/ideep
vendored
|
|
@ -1 +1 @@
|
|||
Subproject commit 8a114a51c116b55c4ceb689b98746786bd00c29b
|
||||
Subproject commit 02b17c5748c9349dcc586c359af800c684d9b1ab
|
||||
|
|
@ -251,6 +251,9 @@ int64_t concrete_dim_fn(
|
|||
c10::IntArrayRef concrete_strides_fn(
|
||||
const c10::impl::PyInterpreter*,
|
||||
const c10::TensorImpl* self);
|
||||
c10::IntArrayRef concrete_sizes_fn(
|
||||
const c10::impl::PyInterpreter*,
|
||||
const c10::TensorImpl* self);
|
||||
|
||||
class PyInterpreterHolder {
|
||||
public:
|
||||
|
|
@ -263,7 +266,8 @@ class PyInterpreterHolder {
|
|||
&concrete_is_contiguous_fn,
|
||||
&concrete_device_fn,
|
||||
&concrete_dim_fn,
|
||||
&concrete_strides_fn)) {}
|
||||
&concrete_strides_fn,
|
||||
&concrete_sizes_fn)) {}
|
||||
// NB: intentionally leaks the memory
|
||||
~PyInterpreterHolder() {
|
||||
impl_->disarm();
|
||||
|
|
@ -2279,7 +2283,49 @@ c10::IntArrayRef concrete_strides_fn(
|
|||
py::object os = py::module_::import("torch").attr("overrides");
|
||||
py::function get_buffer =
|
||||
py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
|
||||
auto buffer = get_buffer(sub, values);
|
||||
auto buffer = get_buffer(sub, values, "stride");
|
||||
auto result = THPUtils_unpackLongs(buffer.ptr());
|
||||
int64_t* start = (int64_t*)result[0];
|
||||
int64_t len = result[1];
|
||||
|
||||
return c10::IntArrayRef(start, len);
|
||||
}
|
||||
|
||||
c10::IntArrayRef concrete_sizes_fn(
|
||||
const c10::impl::PyInterpreter*,
|
||||
const c10::TensorImpl* self) {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
at::impl::MaybeSetTLSOnEntryGuard guard;
|
||||
|
||||
auto out = torchDispatchFromTensorImpl(
|
||||
self,
|
||||
"size",
|
||||
py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr("aten")
|
||||
.attr("size")
|
||||
.attr("default")
|
||||
.ptr(),
|
||||
"torch.ops.aten");
|
||||
|
||||
if (out == Py_None) {
|
||||
return self->sizes_default();
|
||||
}
|
||||
|
||||
py::object values = py::reinterpret_steal<py::object>(out.ptr());
|
||||
|
||||
c10::TensorImpl* ptr = const_cast<c10::TensorImpl*>(self);
|
||||
c10::optional<PyObject*> mb_obj = ptr->check_pyobj(getPyInterpreter());
|
||||
TORCH_CHECK(
|
||||
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
|
||||
PyObject* subclass = *mb_obj;
|
||||
Py_INCREF(subclass);
|
||||
py::object sub = py::reinterpret_steal<py::object>(subclass);
|
||||
|
||||
py::object os = py::module_::import("torch").attr("overrides");
|
||||
py::function get_buffer =
|
||||
py::reinterpret_borrow<py::function>(os.attr("get_buffer"));
|
||||
auto buffer = get_buffer(sub, values, "size");
|
||||
auto result = THPUtils_unpackLongs(buffer.ptr());
|
||||
int64_t* start = (int64_t*)result[0];
|
||||
int64_t len = result[1];
|
||||
|
|
|
|||
|
|
@ -1954,10 +1954,12 @@ class enable_reentrant_dispatch():
|
|||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
del self._raii_guard
|
||||
|
||||
def get_buffer(tensor_subclass, data, prefix):
    """Return ``(pointer, length)`` for ``data`` backed by a ctypes int64 buffer.

    The buffer is stashed on ``tensor_subclass`` (attribute ``_<prefix>_buffer``)
    so it stays alive after this function returns — the C++ caller builds a
    non-owning view over the returned pointer.

    Args:
        tensor_subclass: object to cache the backing buffer on.
        data: iterable of ints (e.g. a size or stride tuple).
        prefix: either ``"stride"`` or ``"size"``; selects the cache slot.

    Returns:
        Tuple of (address of the first element, number of elements).
    """
    import ctypes
    assert prefix in {"stride", "size"}
    buffer_name = f"_{prefix}_buffer"
    buf = getattr(tensor_subclass, buffer_name, None)
    # Bug fix: the previous code reused any existing cached buffer even when
    # `data` had changed (or had a different length) while still returning
    # len(data) — callers could read stale values or past the end of the old
    # buffer. Rebuild the buffer whenever the contents differ; keep the same
    # buffer (stable address) when they match.
    if buf is None or list(buf) != list(data):
        SizeType = ctypes.c_longlong * len(data)
        buf = SizeType(*data)
        setattr(tensor_subclass, buffer_name, buf)
    return (ctypes.addressof(buf), len(data))
|
||||
|
|
|
|||
Loading…
Reference in a new issue