mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Make numel/sym_numel PyInterpreter work symmetrically to others (#113065)
Just some better engineering code cleanup. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/113065 Approved by: https://github.com/voznesenskym
This commit is contained in:
parent
78b8465565
commit
9e6e9587c1
5 changed files with 30 additions and 6 deletions
|
|
@ -689,8 +689,7 @@ int64_t TensorImpl::dim_custom() const {
|
|||
|
||||
int64_t TensorImpl::numel_custom() const {
|
||||
if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
|
||||
// TODO: fix this
|
||||
return pyobj_slot_.load_pyobj_interpreter()->sym_numel(this).expect_int();
|
||||
return pyobj_slot_.load_pyobj_interpreter()->numel(this);
|
||||
}
|
||||
return numel_default();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -81,6 +81,9 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
|||
c10::Layout layout(const TensorImpl* self) const override {
|
||||
PANIC(layout);
|
||||
}
|
||||
int64_t numel(const TensorImpl* self) const override {
|
||||
PANIC(numel);
|
||||
}
|
||||
c10::SymInt sym_numel(const TensorImpl* self) const override {
|
||||
PANIC(sym_numel);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -174,6 +174,7 @@ struct C10_API PyInterpreterVTable {
|
|||
virtual c10::IntArrayRef sizes(const TensorImpl* self) const = 0;
|
||||
virtual c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const = 0;
|
||||
virtual c10::Layout layout(const TensorImpl* self) const = 0;
|
||||
virtual int64_t numel(const TensorImpl* self) const = 0;
|
||||
virtual c10::SymInt sym_numel(const TensorImpl* self) const = 0;
|
||||
virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0;
|
||||
virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0;
|
||||
|
|
|
|||
|
|
@ -2110,7 +2110,7 @@ def forward(self, x_1):
|
|||
def __torch_dispatch__(cls, func, types, args, kwargs):
|
||||
if func.overloadpacket == torch.ops.aten.dim:
|
||||
return data.dim()
|
||||
if func.overloadpacket == torch.ops.aten.sym_numel:
|
||||
if func.overloadpacket == torch.ops.aten.numel:
|
||||
numel_called[0] = True
|
||||
return None
|
||||
return NotImplemented
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ struct ConcretePyInterpreterVTable final
|
|||
c10::IntArrayRef sizes(const c10::TensorImpl* self) const override;
|
||||
c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override;
|
||||
c10::Layout layout(const c10::TensorImpl* self) const override;
|
||||
int64_t numel(const c10::TensorImpl* self) const override;
|
||||
c10::SymInt sym_numel(const c10::TensorImpl* self) const override;
|
||||
c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override;
|
||||
c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override;
|
||||
|
|
@ -814,6 +815,29 @@ c10::Layout ConcretePyInterpreterVTable::layout(
|
|||
return toLayout(out.ptr());
|
||||
}
|
||||
|
||||
int64_t ConcretePyInterpreterVTable::numel(const c10::TensorImpl* self) const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
at::impl::MaybeSetTLSOnEntryGuard guard;
|
||||
auto out = torchDispatchFromTensorImpl(
|
||||
self,
|
||||
"numel",
|
||||
py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr("aten")
|
||||
.attr("numel")
|
||||
.attr("default")
|
||||
.ptr(),
|
||||
"torch.ops.aten");
|
||||
|
||||
if (out.is_none()) {
|
||||
TORCH_CHECK(
|
||||
!self->has_symbolic_sizes_strides(),
|
||||
"Cannot call sizes on a tensor with symbolic shapes/strides");
|
||||
return self->numel_default();
|
||||
}
|
||||
return py::cast<int64_t>(out);
|
||||
}
|
||||
|
||||
c10::SymInt ConcretePyInterpreterVTable::sym_numel(
|
||||
const c10::TensorImpl* self) const {
|
||||
pybind11::gil_scoped_acquire gil;
|
||||
|
|
@ -830,9 +854,6 @@ c10::SymInt ConcretePyInterpreterVTable::sym_numel(
|
|||
"torch.ops.aten");
|
||||
|
||||
if (out.is_none()) {
|
||||
TORCH_CHECK(
|
||||
!self->has_symbolic_sizes_strides(),
|
||||
"Cannot call numel on a tensor with symbolic shapes/strides");
|
||||
return self->sym_numel_default();
|
||||
}
|
||||
return torch::is_symint(out) ? out.cast<c10::SymInt>()
|
||||
|
|
|
|||
Loading…
Reference in a new issue