[BE] Get rid of code duplication (#110619)

Replace `dispatch_to_CDouble`, `dispatch_to_CLong` and `dispatch_to_CComplexDouble` with `dispatch_to<T>` template

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at c3d9d01</samp>

> _Sing, O Muse, of the clever coder who devised_
> _A wondrous template function, `dispatch_to<T>`, that could_
> _Handle with ease the various scalar types that vexed_
> _The previous code, which was verbose and dull as wood._
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110619
Approved by: https://github.com/soulitzer, https://github.com/albanD
ghstack dependencies: #110618
This commit is contained in:
Nikita Shulga 2023-10-05 11:54:26 -07:00 committed by PyTorch MergeBot
parent 82e353fffc
commit be02103786

View file

@ -315,31 +315,14 @@ static Tensor dispatch_copy_(const Tensor & self, const Tensor & other, bool non
END_HANDLE_TH_ERRORS
}
static double dispatch_to_CDouble(const Tensor & self) {
template<typename T>
static T dispatch_to(const Tensor & self) {
pybind11::gil_scoped_release no_gil;
OptionalDeviceGuard device_guard(device_of(self));
if (self.sym_numel() != 1) {
throw ValueError("only one element tensors can be converted to Python scalars");
}
return self.item<double>();
}
static c10::complex<double> dispatch_to_CComplexDouble(const Tensor & self) {
pybind11::gil_scoped_release no_gil;
OptionalDeviceGuard device_guard(device_of(self));
if (self.sym_numel() != 1) {
throw ValueError("only one element tensors can be converted to Python scalars");
}
return self.item<c10::complex<double>>();
}
static int64_t dispatch_to_CLong(const Tensor & self) {
pybind11::gil_scoped_release no_gil;
OptionalDeviceGuard device_guard(device_of(self));
if (self.sym_numel() != 1) {
throw ValueError("only one element tensors can be converted to Python scalars");
}
return self.item<int64_t>();
return self.template item<T>();
}
static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
@ -349,7 +332,7 @@ static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
}
jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
auto& self_ = THPVariable_Unpack(self);
return wrap(dispatch_to_CDouble(self_));
return wrap(dispatch_to<double>(self_));
END_HANDLE_TH_ERRORS
}
@ -360,7 +343,7 @@ static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) {
}
jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW);
auto& self_ = THPVariable_Unpack(self);
return wrap(dispatch_to_CComplexDouble(self_));
return wrap(dispatch_to<c10::complex<double>>(self_));
END_HANDLE_TH_ERRORS
}
@ -374,9 +357,9 @@ static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) {
if (isFloatingType(self_.scalar_type())) {
// we can't dispatch to item<int64_t> here because we want to avoid ATen overflow checks;
// the python integral type (long in python2) can't overflow.
return THPUtils_packDoubleAsInt(dispatch_to_CDouble(self_));
return THPUtils_packDoubleAsInt(dispatch_to<double>(self_));
} else {
return wrap(dispatch_to_CLong(self_));
return wrap(dispatch_to<int64_t>(self_));
}
END_HANDLE_TH_ERRORS
}
@ -394,7 +377,7 @@ static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) {
if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true) || self_.sym_numel() != 1) {
throw TypeError("only integer tensors of a single element can be converted to an index");
}
return wrap(dispatch_to_CLong(self_));
return wrap(dispatch_to<int64_t>(self_));
END_HANDLE_TH_ERRORS
}