From 82bb06334da97bf0cbc673e48366d299c80936b4 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Fri, 8 Mar 2024 20:59:56 +0000
Subject: [PATCH] Update python binding for in-place foreach to return
 `List[Tensor]` (#121405)

fixes #104817
taking over #118622

```c++
// _foreach_atan_
static PyObject * THPVariable__foreach_atan_(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "_foreach_atan_(TensorList self)",
  }, /*traceable=*/false);

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
  if(_r.has_torch_function()) {
    return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  // aten::_foreach_atan_(Tensor(a!)[] self) -> ()
  // auto dispatch__foreach_atan_ = [](at::TensorList self) -> at::TensorList {
  auto dispatch__foreach_atan_ = [](at::TensorList self) -> void {
    pybind11::gil_scoped_release no_gil;
    at::_foreach_atan_(self);
  };
  dispatch__foreach_atan_(_r.tensorlist(0));
  PyObject* self_tensorlist = _r.args[0];
  Py_INCREF(self_tensorlist);
  return self_tensorlist;
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

...

// _foreach_div_
static PyObject * THPVariable__foreach_div_(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "_foreach_div_(TensorList self, ScalarList scalars)",
    "_foreach_div_(TensorList self, Tensor other)",
    "_foreach_div_(TensorList self, TensorList other)",
    "_foreach_div_(TensorList self, Scalar scalar)",
  }, /*traceable=*/false);

  ParsedArgs<2> parsed_args;
  auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
  if(_r.has_torch_function()) {
    return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  switch (_r.idx) {
    case 0: {
      // aten::_foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
      // auto dispatch__foreach_div_ = [](at::TensorList self, at::ArrayRef<at::Scalar> scalars) -> at::TensorList {
      auto dispatch__foreach_div_ = [](at::TensorList self, at::ArrayRef<at::Scalar> scalars) -> void {
        pybind11::gil_scoped_release no_gil;
        at::_foreach_div_(self, scalars);
      };
      dispatch__foreach_div_(_r.tensorlist(0), _r.scalarlist(1));
      PyObject* self_tensorlist = _r.args[0];
      Py_INCREF(self_tensorlist);
      return self_tensorlist;
    }
    case 1: {
      // aten::_foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
      // auto dispatch__foreach_div_ = [](at::TensorList self, const at::Tensor & other) -> at::TensorList {
      auto dispatch__foreach_div_ = [](at::TensorList self, const at::Tensor & other) -> void {
        pybind11::gil_scoped_release no_gil;
        at::_foreach_div_(self, other);
      };
      dispatch__foreach_div_(_r.tensorlist(0), _r.tensor(1));
      PyObject* self_tensorlist = _r.args[0];
      Py_INCREF(self_tensorlist);
      return self_tensorlist;
    }
    case 2: {
      // aten::_foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> ()
      // auto dispatch__foreach_div_ = [](at::TensorList self, at::TensorList other) -> at::TensorList {
      auto dispatch__foreach_div_ = [](at::TensorList self, at::TensorList other) -> void {
        pybind11::gil_scoped_release no_gil;
        at::_foreach_div_(self, other);
      };
      dispatch__foreach_div_(_r.tensorlist(0), _r.tensorlist(1));
      PyObject* self_tensorlist = _r.args[0];
      Py_INCREF(self_tensorlist);
      return self_tensorlist;
    }
    case 3: {
      // aten::_foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
      // auto dispatch__foreach_div_ = [](at::TensorList self, const at::Scalar & scalar) -> at::TensorList {
      auto dispatch__foreach_div_ = [](at::TensorList self, const at::Scalar & scalar) -> void {
        pybind11::gil_scoped_release no_gil;
        at::_foreach_div_(self, scalar);
      };
      dispatch__foreach_div_(_r.tensorlist(0), _r.scalar(1));
      PyObject* self_tensorlist = _r.args[0];
      Py_INCREF(self_tensorlist);
      return self_tensorlist;
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
```
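At the Python level this boils down to the following (a minimal sketch; the op and shapes are arbitrary, and before this patch `out` would have been `None`):

```python
import torch

tensors = [torch.randn(3) for _ in range(4)]
out = torch._foreach_atan_(tensors)  # in-place; mutates every tensor in the list

# The binding INCREFs and returns the original argument, so the result is
# the very same Python list object that was passed in, not a copy.
assert out is tensors
```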
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121405
Approved by: https://github.com/soulitzer
---
 test/test_foreach.py                   |  5 +++--
 tools/autograd/gen_python_functions.py | 23 ++++++++++++++++++++++-
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/test/test_foreach.py b/test/test_foreach.py
index a28197add38..41127f45396 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -66,8 +66,9 @@ class ForeachFuncWrapper:
             assert mta_called == (expect_fastpath and (not zero_size))
         else:
             actual = self.func(*inputs, **kwargs)
-        # note(mkozuki): inplace foreach functions are void functions.
-        return inputs[0] if self.is_inplace else actual
+        if self.is_inplace:
+            assert id(inputs[0]) == id(actual)
+        return actual
 
 
 class InplaceForeachVersionBumpCheck:
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 29ccf120590..be942ca5bfb 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -64,12 +64,14 @@ from torchgen.model import (
     BaseOperatorName,
     FunctionSchema,
     NativeFunction,
+    SchemaKind,
     Type,
     Variant,
 )
 from torchgen.utils import FileManager, split_name_params
 from torchgen.yaml_utils import YamlLoader
 
+from .gen_inplace_or_view_type import is_tensor_list_type
 from .gen_trace_type import should_trace
 
 #
@@ -1352,6 +1354,25 @@ def emit_single_dispatch(
         )
 
         if lambda_return == "void":
+            # Make in-place foreach return `self` at python-binding level.
+            # ref: https://github.com/pytorch/pytorch/pull/118622#pullrequestreview-1904804954
+            self_arg = f.func.arguments.self_arg
+            return_stmt: str
+            if (
+                str(f.func.name).startswith("_foreach_")
+                and f.func.kind() == SchemaKind.inplace
+            ):
+                # note(crcrpar): `_foreach_pow.ScalarAndTensor` does NOT have its in-place
+                # variant and it is unlikely to have one in the future. Thus it's safe to have the following assert.
+                assert self_arg is not None and is_tensor_list_type(
+                    self_arg.argument.type
+                )
+                return_stmt = """PyObject* self_tensorlist = _r.args[0];
+Py_INCREF(self_tensorlist);
+return self_tensorlist;
+"""
+            else:
+                return_stmt = "Py_RETURN_NONE;"
             return f"""\
 {schema_comment}
 {inits}
@@ -1360,7 +1381,7 @@ auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
 {dispatch_callee}({dispatch_args});
 }};
 dispatch_{name}({lambda_args}){set_requires_grad};
-Py_RETURN_NONE;
+{return_stmt}
 """
         else:
             typename = structseq_typenames.get(gen_structseq_typename_key(f))
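For readers skimming the codegen hunk above: the new branch only swaps the emitted return statement, and only for in-place foreach ops. A simplified, hypothetical stand-alone sketch of that selection (`op_name` and `is_inplace` stand in for `f.func.name` and `f.func.kind() == SchemaKind.inplace`; the real code operates on torchgen schema objects):

```python
def pick_return_stmt(op_name: str, is_inplace: bool) -> str:
    # Mirror of the `lambda_return == "void"` branch: in-place foreach
    # bindings return `self` (the input TensorList); every other
    # void-returning binding keeps returning None.
    if op_name.startswith("_foreach_") and is_inplace:
        return (
            "PyObject* self_tensorlist = _r.args[0];\n"
            "Py_INCREF(self_tensorlist);\n"
            "return self_tensorlist;\n"
        )
    return "Py_RETURN_NONE;"


assert pick_return_stmt("_foreach_atan_", is_inplace=True).startswith("PyObject*")
assert pick_return_stmt("_foreach_atan", is_inplace=False) == "Py_RETURN_NONE;"
```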