Make numel/sym_numel PyInterpreter work symmetrically to others (#113065)

Just some better-engineering code cleanup: numel now has its own PyInterpreter vtable entry alongside sym_numel, so TensorImpl::numel_custom() no longer round-trips through sym_numel(this).expect_int(). This matches how the other accessors (e.g. sizes/sym_sizes) are already paired.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113065
Approved by: https://github.com/voznesenskym
commit 9e6e9587c1
parent 78b8465565
Author: Edward Z. Yang
Date:   2023-11-07 20:17:30 +00:00 (committed by PyTorch MergeBot)

5 changed files with 30 additions and 6 deletions
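
For orientation: both overload packets already exist in aten, and the cleanup gives each its own PyInterpreter vtable slot instead of having the plain one piggyback on the symbolic one. A minimal illustration (not from the PR):

import torch

# Two distinct overload packets. After this change, C++ callers of
# TensorImpl::numel() on a CustomSizes subclass dispatch aten.numel,
# while sym_numel() keeps dispatching aten.sym_numel.
print(torch.ops.aten.numel)
print(torch.ops.aten.sym_numel)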

@@ -689,8 +689,7 @@ int64_t TensorImpl::dim_custom() const {
 int64_t TensorImpl::numel_custom() const {
   if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
-    // TODO: fix this
-    return pyobj_slot_.load_pyobj_interpreter()->sym_numel(this).expect_int();
+    return pyobj_slot_.load_pyobj_interpreter()->numel(this);
   }
   return numel_default();
 }
 

@@ -81,6 +81,9 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
   c10::Layout layout(const TensorImpl* self) const override {
     PANIC(layout);
   }
+  int64_t numel(const TensorImpl* self) const override {
+    PANIC(numel);
+  }
   c10::SymInt sym_numel(const TensorImpl* self) const override {
     PANIC(sym_numel);
   }

@@ -174,6 +174,7 @@ struct C10_API PyInterpreterVTable {
   virtual c10::IntArrayRef sizes(const TensorImpl* self) const = 0;
   virtual c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const = 0;
   virtual c10::Layout layout(const TensorImpl* self) const = 0;
+  virtual int64_t numel(const TensorImpl* self) const = 0;
   virtual c10::SymInt sym_numel(const TensorImpl* self) const = 0;
   virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0;
   virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0;

@@ -2110,7 +2110,7 @@ def forward(self, x_1):
         def __torch_dispatch__(cls, func, types, args, kwargs):
             if func.overloadpacket == torch.ops.aten.dim:
                 return data.dim()
-            if func.overloadpacket == torch.ops.aten.sym_numel:
+            if func.overloadpacket == torch.ops.aten.numel:
                 numel_called[0] = True
                 return None
             return NotImplemented
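
To make the test's intent concrete, here is a minimal runnable sketch along the same lines (a sketch, not code from the PR: the class name and shape are illustrative, and it assumes a build that includes this change):

import torch

calls = []

class LoggingNumel(torch.Tensor):
    # Wrapper subclass using the CustomSizes policy, so numel/size queries
    # on the C++ TensorImpl are routed to __torch_dispatch__.
    @staticmethod
    def __new__(cls, shape):
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, dispatch_sizes_strides_policy="sizes"
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        calls.append(func.overloadpacket)
        # As in the test above: returning None defers to the default
        # C++ implementation (numel_default()/sym_numel_default()).
        return None

t = LoggingNumel((2, 3))
print(t.numel())  # 6, computed by the C++ fallback
# Which packet was recorded depends on the entry point; with this change,
# plain-int numel queries from C++ arrive as aten.numel instead of being
# funneled through aten.sym_numel + expect_int().
print(calls)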

@@ -78,6 +78,7 @@ struct ConcretePyInterpreterVTable final
   c10::IntArrayRef sizes(const c10::TensorImpl* self) const override;
   c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override;
   c10::Layout layout(const c10::TensorImpl* self) const override;
+  int64_t numel(const c10::TensorImpl* self) const override;
   c10::SymInt sym_numel(const c10::TensorImpl* self) const override;
   c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
   c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;
@@ -814,6 +815,29 @@ c10::Layout ConcretePyInterpreterVTable::layout(
   return toLayout(out.ptr());
 }
 
+int64_t ConcretePyInterpreterVTable::numel(const c10::TensorImpl* self) const {
+  pybind11::gil_scoped_acquire gil;
+  at::impl::MaybeSetTLSOnEntryGuard guard;
+  auto out = torchDispatchFromTensorImpl(
+      self,
+      "numel",
+      py::module::import("torch")
+          .attr("ops")
+          .attr("aten")
+          .attr("numel")
+          .attr("default")
+          .ptr(),
+      "torch.ops.aten");
+
+  if (out.is_none()) {
+    TORCH_CHECK(
+        !self->has_symbolic_sizes_strides(),
+        "Cannot call numel on a tensor with symbolic shapes/strides");
+    return self->numel_default();
+  }
+  return py::cast<int64_t>(out);
+}
+
 c10::SymInt ConcretePyInterpreterVTable::sym_numel(
     const c10::TensorImpl* self) const {
   pybind11::gil_scoped_acquire gil;
@@ -830,9 +854,6 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel(
       "torch.ops.aten");
 
   if (out.is_none()) {
-    TORCH_CHECK(
-        !self->has_symbolic_sizes_strides(),
-        "Cannot call numel on a tensor with symbolic shapes/strides");
     return self->sym_numel_default();
   }
   return torch::is_symint(out) ? out.cast<c10::SymInt>()
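
One note on the symmetry this restores: when __torch_dispatch__ returns None, the numel slot falls back to numel_default(), which can only produce a concrete int64_t, so the has_symbolic_sizes_strides check belongs there. sym_numel's fallback, sym_numel_default(), can carry a symbolic value, which is why the check deleted above was unnecessary on that path.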