From 9e6e9587c1ddcb05334984f4a0338f757c686c19 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 7 Nov 2023 20:17:30 +0000 Subject: [PATCH] Make numel/sym_numel PyInterpreter work symmetrically to others (#113065) Just some better engineering code cleanup. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/113065 Approved by: https://github.com/voznesenskym --- c10/core/TensorImpl.cpp | 3 +-- c10/core/impl/PyInterpreter.cpp | 3 +++ c10/core/impl/PyInterpreter.h | 1 + test/test_python_dispatch.py | 2 +- torch/csrc/PyInterpreter.cpp | 27 ++++++++++++++++++++++++--- 5 files changed, 30 insertions(+), 6 deletions(-) diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index f9ff722df49..aa74c784fad 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -689,8 +689,7 @@ int64_t TensorImpl::dim_custom() const { int64_t TensorImpl::numel_custom() const { if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) { - // TODO: fix this - return pyobj_slot_.load_pyobj_interpreter()->sym_numel(this).expect_int(); + return pyobj_slot_.load_pyobj_interpreter()->numel(this); } return numel_default(); } diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp index 20c49b94b26..f555ee5c345 100644 --- a/c10/core/impl/PyInterpreter.cpp +++ b/c10/core/impl/PyInterpreter.cpp @@ -81,6 +81,9 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable { c10::Layout layout(const TensorImpl* self) const override { PANIC(layout); } + int64_t numel(const TensorImpl* self) const override { + PANIC(numel); + } c10::SymInt sym_numel(const TensorImpl* self) const override { PANIC(sym_numel); } diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h index a5faa2fd377..ab5cb6c5e5b 100644 --- a/c10/core/impl/PyInterpreter.h +++ b/c10/core/impl/PyInterpreter.h @@ -174,6 +174,7 @@ struct C10_API PyInterpreterVTable { virtual c10::IntArrayRef sizes(const TensorImpl* self) const = 0; virtual c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const = 0; virtual c10::Layout layout(const TensorImpl* self) const = 0; + virtual int64_t numel(const TensorImpl* self) const = 0; virtual c10::SymInt sym_numel(const TensorImpl* self) const = 0; virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0; virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0; diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 3790ec6ed70..622f3b8e912 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -2110,7 +2110,7 @@ def forward(self, x_1): def __torch_dispatch__(cls, func, types, args, kwargs): if func.overloadpacket == torch.ops.aten.dim: return data.dim() - if func.overloadpacket == torch.ops.aten.sym_numel: + if func.overloadpacket == torch.ops.aten.numel: numel_called[0] = True return None return NotImplemented diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp index d5776ae62d6..3cd16ea7b9a 100644 --- a/torch/csrc/PyInterpreter.cpp +++ b/torch/csrc/PyInterpreter.cpp @@ -78,6 +78,7 @@ struct ConcretePyInterpreterVTable final c10::IntArrayRef sizes(const c10::TensorImpl* self) const override; c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override; c10::Layout layout(const c10::TensorImpl* self) const override; + int64_t numel(const c10::TensorImpl* self) const override; c10::SymInt sym_numel(const c10::TensorImpl* self) const override; c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override; c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override; @@ -814,6 +815,29 @@ c10::Layout ConcretePyInterpreterVTable::layout( return toLayout(out.ptr()); } +int64_t ConcretePyInterpreterVTable::numel(const c10::TensorImpl* self) const { + pybind11::gil_scoped_acquire gil; + at::impl::MaybeSetTLSOnEntryGuard guard; + auto out = torchDispatchFromTensorImpl( + self, + "numel", + py::module::import("torch") + .attr("ops") + .attr("aten") + .attr("numel") + .attr("default") + .ptr(), + "torch.ops.aten"); + + if (out.is_none()) { + TORCH_CHECK( + !self->has_symbolic_sizes_strides(), + "Cannot call sizes on a tensor with symbolic shapes/strides"); + return self->numel_default(); + } + return py::cast(out); +} + c10::SymInt ConcretePyInterpreterVTable::sym_numel( const c10::TensorImpl* self) const { pybind11::gil_scoped_acquire gil; @@ -830,9 +854,6 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel( "torch.ops.aten"); if (out.is_none()) { - TORCH_CHECK( - !self->has_symbolic_sizes_strides(), - "Cannot call numel on a tensor with symbolic shapes/strides"); return self->sym_numel_default(); } return torch::is_symint(out) ? out.cast()