From be02103786d7e455a010df3bd9f580fa6c0c8d37 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 5 Oct 2023 11:54:26 -0700 Subject: [PATCH] [BE] Get rid of code duplication (#110619) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace `dispatch_to_CDouble`, `dispatch_to_CLong` and `dispatch_to_CComplexDouble` with `dispatch_to` template ### 🤖 Generated by Copilot at c3d9d01 > _Sing, O Muse, of the clever coder who devised_ > _A wondrous template function, `dispatch_to`, that could_ > _Handle with ease the various scalar types that vexed_ > _The previous code, which was verbose and dull as wood._ Pull Request resolved: https://github.com/pytorch/pytorch/pull/110619 Approved by: https://github.com/soulitzer, https://github.com/albanD ghstack dependencies: #110618 --- .../templates/python_variable_methods.cpp | 33 +++++-------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index e39293e0538..ea5cded872d 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -315,31 +315,14 @@ static Tensor dispatch_copy_(const Tensor & self, const Tensor & other, bool non END_HANDLE_TH_ERRORS } -static double dispatch_to_CDouble(const Tensor & self) { +template +static T dispatch_to(const Tensor & self) { pybind11::gil_scoped_release no_gil; OptionalDeviceGuard device_guard(device_of(self)); if (self.sym_numel() != 1) { throw ValueError("only one element tensors can be converted to Python scalars"); } - return self.item(); -} - -static c10::complex dispatch_to_CComplexDouble(const Tensor & self) { - pybind11::gil_scoped_release no_gil; - OptionalDeviceGuard device_guard(device_of(self)); - if (self.sym_numel() != 1) { - throw ValueError("only one element tensors can be converted to Python scalars"); - } - return self.item>(); -} - -static int64_t dispatch_to_CLong(const Tensor & self) { - pybind11::gil_scoped_release no_gil; - OptionalDeviceGuard device_guard(device_of(self)); - if (self.sym_numel() != 1) { - throw ValueError("only one element tensors can be converted to Python scalars"); - } - return self.item(); + return self.template item(); } static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { @@ -349,7 +332,7 @@ static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { } jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = THPVariable_Unpack(self); - return wrap(dispatch_to_CDouble(self_)); + return wrap(dispatch_to(self_)); END_HANDLE_TH_ERRORS } @@ -360,7 +343,7 @@ static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) { } jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = THPVariable_Unpack(self); - return wrap(dispatch_to_CComplexDouble(self_)); + return wrap(dispatch_to>(self_)); END_HANDLE_TH_ERRORS } @@ -374,9 +357,9 @@ static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { if (isFloatingType(self_.scalar_type())) { // we can't dispatch to item here because we want to avoid ATen overflow checks; // the python integral type (long in python2) can't overflow. - return THPUtils_packDoubleAsInt(dispatch_to_CDouble(self_)); + return THPUtils_packDoubleAsInt(dispatch_to(self_)); } else { - return wrap(dispatch_to_CLong(self_)); + return wrap(dispatch_to(self_)); } END_HANDLE_TH_ERRORS } @@ -394,7 +377,7 @@ static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) { if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true) || self_.sym_numel() != 1) { throw TypeError("only integer tensors of a single element can be converted to an index"); } - return wrap(dispatch_to_CLong(self_)); + return wrap(dispatch_to(self_)); END_HANDLE_TH_ERRORS }