diff --git a/test/test_foreach.py b/test/test_foreach.py index a28197add38..41127f45396 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -66,8 +66,9 @@ class ForeachFuncWrapper: assert mta_called == (expect_fastpath and (not zero_size)) else: actual = self.func(*inputs, **kwargs) - # note(mkozuki): inplace foreach functions are void functions. - return inputs[0] if self.is_inplace else actual + if self.is_inplace: + assert id(inputs[0]) == id(actual) + return actual class InplaceForeachVersionBumpCheck: diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 29ccf120590..be942ca5bfb 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -64,12 +64,14 @@ from torchgen.model import ( BaseOperatorName, FunctionSchema, NativeFunction, + SchemaKind, Type, Variant, ) from torchgen.utils import FileManager, split_name_params from torchgen.yaml_utils import YamlLoader +from .gen_inplace_or_view_type import is_tensor_list_type from .gen_trace_type import should_trace # @@ -1352,6 +1354,25 @@ def emit_single_dispatch( ) if lambda_return == "void": + # Make in-place foreach return `self` at python-binding level. + # ref: https://github.com/pytorch/pytorch/pull/118622#pullrequestreview-1904804954 + self_arg = f.func.arguments.self_arg + return_stmt: str + if ( + str(f.func.name).startswith("_foreach_") + and f.func.kind() == SchemaKind.inplace + ): + # note(crcrpar): `_foreach_pow.ScalarAndTensor` does NOT have its in-place + # variant and it unlikely to have it in the future. Thus it's safe to have the following assert. + assert self_arg is not None and is_tensor_list_type( + self_arg.argument.type + ) + return_stmt = """PyObject* self_tensorlist = _r.args[0]; +Py_INCREF(self_tensorlist); +return self_tensorlist; +""" + else: + return_stmt = "Py_RETURN_NONE;" return f"""\ {schema_comment} {inits} @@ -1360,7 +1381,7 @@ auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{ {dispatch_callee}({dispatch_args}); }}; dispatch_{name}({lambda_args}){set_requires_grad}; -Py_RETURN_NONE; +{return_stmt} """ else: typename = structseq_typenames.get(gen_structseq_typename_key(f))