mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update python binding for in-place foreach to return List[Tensor] (#121405)
fixes #104817 taking over #118622 ```c++ // _foreach_atan_ static PyObject * THPVariable__foreach_atan_(PyObject* self_, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ "_foreach_atan_(TensorList self)", }, /*traceable=*/false); ParsedArgs<1> parsed_args; auto _r = parser.parse(nullptr, args, kwargs, parsed_args); if(_r.has_torch_function()) { return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); } // aten::_foreach_atan_(Tensor(a!)[] self) -> () // auto dispatch__foreach_atan_ = [](at::TensorList self) -> at::TensorList { auto dispatch__foreach_atan_ = [](at::TensorList self) -> void { pybind11::gil_scoped_release no_gil; at::_foreach_atan_(self); }; dispatch__foreach_atan_(_r.tensorlist(0)); PyObject* self_tensorlist = _r.args[0]; Py_INCREF(self_tensorlist); return self_tensorlist; Py_RETURN_NONE; END_HANDLE_TH_ERRORS } ... // _foreach_div_ static PyObject * THPVariable__foreach_div_(PyObject* self_, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS static PythonArgParser parser({ "_foreach_div_(TensorList self, ScalarList scalars)", "_foreach_div_(TensorList self, Tensor other)", "_foreach_div_(TensorList self, TensorList other)", "_foreach_div_(TensorList self, Scalar scalar)", }, /*traceable=*/false); ParsedArgs<2> parsed_args; auto _r = parser.parse(nullptr, args, kwargs, parsed_args); if(_r.has_torch_function()) { return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch"); } switch (_r.idx) { case 0: { // aten::_foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () // auto dispatch__foreach_div_ = [](at::TensorList self, at::ArrayRef<at::Scalar> scalars) -> at::TensorList { auto dispatch__foreach_div_ = [](at::TensorList self, at::ArrayRef<at::Scalar> scalars) -> void { pybind11::gil_scoped_release no_gil; at::_foreach_div_(self, scalars); }; dispatch__foreach_div_(_r.tensorlist(0), _r.scalarlist(1)); PyObject* self_tensorlist = _r.args[0]; Py_INCREF(self_tensorlist); return self_tensorlist; } case 1: { // aten::_foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> () // auto dispatch__foreach_div_ = [](at::TensorList self, const at::Tensor & other) -> at::TensorList { auto dispatch__foreach_div_ = [](at::TensorList self, const at::Tensor & other) -> void { pybind11::gil_scoped_release no_gil; at::_foreach_div_(self, other); }; dispatch__foreach_div_(_r.tensorlist(0), _r.tensor(1)); PyObject* self_tensorlist = _r.args[0]; Py_INCREF(self_tensorlist); return self_tensorlist; } case 2: { // aten::_foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> () // auto dispatch__foreach_div_ = [](at::TensorList self, at::TensorList other) -> at::TensorList { auto dispatch__foreach_div_ = [](at::TensorList self, at::TensorList other) -> void { pybind11::gil_scoped_release no_gil; at::_foreach_div_(self, other); }; dispatch__foreach_div_(_r.tensorlist(0), _r.tensorlist(1)); PyObject* self_tensorlist = _r.args[0]; Py_INCREF(self_tensorlist); return self_tensorlist; } case 3: { // aten::_foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () // auto dispatch__foreach_div_ = [](at::TensorList self, const at::Scalar & scalar) -> at::TensorList { auto dispatch__foreach_div_ = [](at::TensorList self, const at::Scalar & scalar) -> void { pybind11::gil_scoped_release no_gil; at::_foreach_div_(self, scalar); }; dispatch__foreach_div_(_r.tensorlist(0), _r.scalar(1)); PyObject* self_tensorlist = _r.args[0]; Py_INCREF(self_tensorlist); return self_tensorlist; } } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/121405 Approved by: https://github.com/soulitzer
This commit is contained in:
parent
d27509c384
commit
82bb06334d
2 changed files with 25 additions and 3 deletions
|
|
@ -66,8 +66,9 @@ class ForeachFuncWrapper:
|
|||
assert mta_called == (expect_fastpath and (not zero_size))
|
||||
else:
|
||||
actual = self.func(*inputs, **kwargs)
|
||||
# note(mkozuki): inplace foreach functions are void functions.
|
||||
return inputs[0] if self.is_inplace else actual
|
||||
if self.is_inplace:
|
||||
assert id(inputs[0]) == id(actual)
|
||||
return actual
|
||||
|
||||
|
||||
class InplaceForeachVersionBumpCheck:
|
||||
|
|
|
|||
|
|
@ -64,12 +64,14 @@ from torchgen.model import (
|
|||
BaseOperatorName,
|
||||
FunctionSchema,
|
||||
NativeFunction,
|
||||
SchemaKind,
|
||||
Type,
|
||||
Variant,
|
||||
)
|
||||
from torchgen.utils import FileManager, split_name_params
|
||||
from torchgen.yaml_utils import YamlLoader
|
||||
|
||||
from .gen_inplace_or_view_type import is_tensor_list_type
|
||||
from .gen_trace_type import should_trace
|
||||
|
||||
#
|
||||
|
|
@ -1352,6 +1354,25 @@ def emit_single_dispatch(
|
|||
)
|
||||
|
||||
if lambda_return == "void":
|
||||
# Make in-place foreach return `self` at python-binding level.
|
||||
# ref: https://github.com/pytorch/pytorch/pull/118622#pullrequestreview-1904804954
|
||||
self_arg = f.func.arguments.self_arg
|
||||
return_stmt: str
|
||||
if (
|
||||
str(f.func.name).startswith("_foreach_")
|
||||
and f.func.kind() == SchemaKind.inplace
|
||||
):
|
||||
# note(crcrpar): `_foreach_pow.ScalarAndTensor` does NOT have its in-place
|
||||
# variant and it unlikely to have it in the future. Thus it's safe to have the following assert.
|
||||
assert self_arg is not None and is_tensor_list_type(
|
||||
self_arg.argument.type
|
||||
)
|
||||
return_stmt = """PyObject* self_tensorlist = _r.args[0];
|
||||
Py_INCREF(self_tensorlist);
|
||||
return self_tensorlist;
|
||||
"""
|
||||
else:
|
||||
return_stmt = "Py_RETURN_NONE;"
|
||||
return f"""\
|
||||
{schema_comment}
|
||||
{inits}
|
||||
|
|
@ -1360,7 +1381,7 @@ auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
|
|||
{dispatch_callee}({dispatch_args});
|
||||
}};
|
||||
dispatch_{name}({lambda_args}){set_requires_grad};
|
||||
Py_RETURN_NONE;
|
||||
{return_stmt}
|
||||
"""
|
||||
else:
|
||||
typename = structseq_typenames.get(gen_structseq_typename_key(f))
|
||||
|
|
|
|||
Loading…
Reference in a new issue