mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[BE] Get rid of code duplication (#110619)
Replace `dispatch_to_CDouble`, `dispatch_to_CLong` and `dispatch_to_CComplexDouble` with `dispatch_to<T>` template <!-- copilot:poem --> ### <samp>🤖 Generated by Copilot at c3d9d01</samp> > _Sing, O Muse, of the clever coder who devised_ > _A wondrous template function, `dispatch_to<T>`, that could_ > _Handle with ease the various scalar types that vexed_ > _The previous code, which was verbose and dull as wood._ Pull Request resolved: https://github.com/pytorch/pytorch/pull/110619 Approved by: https://github.com/soulitzer, https://github.com/albanD ghstack dependencies: #110618
This commit is contained in:
parent
82e353fffc
commit
be02103786
1 changed files with 8 additions and 25 deletions
|
|
@ -315,31 +315,14 @@ static Tensor dispatch_copy_(const Tensor & self, const Tensor & other, bool non
|
|||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static double dispatch_to_CDouble(const Tensor & self) {
|
||||
template<typename T>
|
||||
static T dispatch_to(const Tensor & self) {
|
||||
pybind11::gil_scoped_release no_gil;
|
||||
OptionalDeviceGuard device_guard(device_of(self));
|
||||
if (self.sym_numel() != 1) {
|
||||
throw ValueError("only one element tensors can be converted to Python scalars");
|
||||
}
|
||||
return self.item<double>();
|
||||
}
|
||||
|
||||
static c10::complex<double> dispatch_to_CComplexDouble(const Tensor & self) {
|
||||
pybind11::gil_scoped_release no_gil;
|
||||
OptionalDeviceGuard device_guard(device_of(self));
|
||||
if (self.sym_numel() != 1) {
|
||||
throw ValueError("only one element tensors can be converted to Python scalars");
|
||||
}
|
||||
return self.item<c10::complex<double>>();
|
||||
}
|
||||
|
||||
static int64_t dispatch_to_CLong(const Tensor & self) {
|
||||
pybind11::gil_scoped_release no_gil;
|
||||
OptionalDeviceGuard device_guard(device_of(self));
|
||||
if (self.sym_numel() != 1) {
|
||||
throw ValueError("only one element tensors can be converted to Python scalars");
|
||||
}
|
||||
return self.item<int64_t>();
|
||||
return self.template item<T>();
|
||||
}
|
||||
|
||||
static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
|
||||
|
|
@ -349,7 +332,7 @@ static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
|
|||
}
|
||||
jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
|
||||
auto& self_ = THPVariable_Unpack(self);
|
||||
return wrap(dispatch_to_CDouble(self_));
|
||||
return wrap(dispatch_to<double>(self_));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
|
@ -360,7 +343,7 @@ static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) {
|
|||
}
|
||||
jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW);
|
||||
auto& self_ = THPVariable_Unpack(self);
|
||||
return wrap(dispatch_to_CComplexDouble(self_));
|
||||
return wrap(dispatch_to<c10::complex<double>>(self_));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
|
@ -374,9 +357,9 @@ static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) {
|
|||
if (isFloatingType(self_.scalar_type())) {
|
||||
// we can't dispatch to item<int64_t> here because we want to avoid ATen overflow checks;
|
||||
// the python integral type (long in python2) can't overflow.
|
||||
return THPUtils_packDoubleAsInt(dispatch_to_CDouble(self_));
|
||||
return THPUtils_packDoubleAsInt(dispatch_to<double>(self_));
|
||||
} else {
|
||||
return wrap(dispatch_to_CLong(self_));
|
||||
return wrap(dispatch_to<int64_t>(self_));
|
||||
}
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
|
@ -394,7 +377,7 @@ static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) {
|
|||
if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true) || self_.sym_numel() != 1) {
|
||||
throw TypeError("only integer tensors of a single element can be converted to an index");
|
||||
}
|
||||
return wrap(dispatch_to_CLong(self_));
|
||||
return wrap(dispatch_to<int64_t>(self_));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue