Update python binding for in-place foreach to return List[Tensor] (#121405)

fixes #104817
taking over #118622

```c++
// _foreach_atan_
static PyObject * THPVariable__foreach_atan_(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "_foreach_atan_(TensorList self)",
  }, /*traceable=*/false);

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
  if(_r.has_torch_function()) {
    return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  // aten::_foreach_atan_(Tensor(a!)[] self) -> ()

  // auto dispatch__foreach_atan_ = [](at::TensorList self) -> at::TensorList {
  auto dispatch__foreach_atan_ = [](at::TensorList self) -> void {
    pybind11::gil_scoped_release no_gil;
    at::_foreach_atan_(self);
  };
  dispatch__foreach_atan_(_r.tensorlist(0));
  PyObject* self_tensorlist = _r.args[0];
  Py_INCREF(self_tensorlist);
  return self_tensorlist;
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
...
// _foreach_div_
static PyObject * THPVariable__foreach_div_(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    "_foreach_div_(TensorList self, ScalarList scalars)",
    "_foreach_div_(TensorList self, Tensor other)",
    "_foreach_div_(TensorList self, TensorList other)",
    "_foreach_div_(TensorList self, Scalar scalar)",
  }, /*traceable=*/false);

  ParsedArgs<2> parsed_args;
  auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
  if(_r.has_torch_function()) {
    return handle_torch_function(_r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }
  switch (_r.idx) {
    case 0: {
      // aten::_foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()

      // auto dispatch__foreach_div_ = [](at::TensorList self, at::ArrayRef<at::Scalar> scalars) -> at::TensorList {
      auto dispatch__foreach_div_ = [](at::TensorList self, at::ArrayRef<at::Scalar> scalars) -> void {
        pybind11::gil_scoped_release no_gil;
        at::_foreach_div_(self, scalars);
      };
      dispatch__foreach_div_(_r.tensorlist(0), _r.scalarlist(1));
      PyObject* self_tensorlist = _r.args[0];
      Py_INCREF(self_tensorlist);
      return self_tensorlist;
    }
    case 1: {
      // aten::_foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> ()

      // auto dispatch__foreach_div_ = [](at::TensorList self, const at::Tensor & other) -> at::TensorList {
      auto dispatch__foreach_div_ = [](at::TensorList self, const at::Tensor & other) -> void {
        pybind11::gil_scoped_release no_gil;
        at::_foreach_div_(self, other);
      };
      dispatch__foreach_div_(_r.tensorlist(0), _r.tensor(1));
      PyObject* self_tensorlist = _r.args[0];
      Py_INCREF(self_tensorlist);
      return self_tensorlist;
    }
    case 2: {
      // aten::_foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> ()

      // auto dispatch__foreach_div_ = [](at::TensorList self, at::TensorList other) -> at::TensorList {
      auto dispatch__foreach_div_ = [](at::TensorList self, at::TensorList other) -> void {
        pybind11::gil_scoped_release no_gil;
        at::_foreach_div_(self, other);
      };
      dispatch__foreach_div_(_r.tensorlist(0), _r.tensorlist(1));
      PyObject* self_tensorlist = _r.args[0];
      Py_INCREF(self_tensorlist);
      return self_tensorlist;
    }
    case 3: {
      // aten::_foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()

      // auto dispatch__foreach_div_ = [](at::TensorList self, const at::Scalar & scalar) -> at::TensorList {
      auto dispatch__foreach_div_ = [](at::TensorList self, const at::Scalar & scalar) -> void {
        pybind11::gil_scoped_release no_gil;
        at::_foreach_div_(self, scalar);
      };
      dispatch__foreach_div_(_r.tensorlist(0), _r.scalar(1));
      PyObject* self_tensorlist = _r.args[0];
      Py_INCREF(self_tensorlist);
      return self_tensorlist;
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121405
Approved by: https://github.com/soulitzer
This commit is contained in:
Masaki Kozuki 2024-03-08 20:59:56 +00:00 committed by PyTorch MergeBot
parent d27509c384
commit 82bb06334d
2 changed files with 25 additions and 3 deletions

View file

@ -66,8 +66,9 @@ class ForeachFuncWrapper:
assert mta_called == (expect_fastpath and (not zero_size))
else:
actual = self.func(*inputs, **kwargs)
# note(mkozuki): inplace foreach functions are void functions.
return inputs[0] if self.is_inplace else actual
if self.is_inplace:
assert id(inputs[0]) == id(actual)
return actual
class InplaceForeachVersionBumpCheck:

View file

@ -64,12 +64,14 @@ from torchgen.model import (
BaseOperatorName,
FunctionSchema,
NativeFunction,
SchemaKind,
Type,
Variant,
)
from torchgen.utils import FileManager, split_name_params
from torchgen.yaml_utils import YamlLoader
from .gen_inplace_or_view_type import is_tensor_list_type
from .gen_trace_type import should_trace
#
@ -1352,6 +1354,25 @@ def emit_single_dispatch(
)
if lambda_return == "void":
# Make in-place foreach return `self` at python-binding level.
# ref: https://github.com/pytorch/pytorch/pull/118622#pullrequestreview-1904804954
self_arg = f.func.arguments.self_arg
return_stmt: str
if (
str(f.func.name).startswith("_foreach_")
and f.func.kind() == SchemaKind.inplace
):
# note(crcrpar): `_foreach_pow.ScalarAndTensor` does NOT have its in-place
# variant and it unlikely to have it in the future. Thus it's safe to have the following assert.
assert self_arg is not None and is_tensor_list_type(
self_arg.argument.type
)
return_stmt = """PyObject* self_tensorlist = _r.args[0];
Py_INCREF(self_tensorlist);
return self_tensorlist;
"""
else:
return_stmt = "Py_RETURN_NONE;"
return f"""\
{schema_comment}
{inits}
@ -1360,7 +1381,7 @@ auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
{dispatch_callee}({dispatch_args});
}};
dispatch_{name}({lambda_args}){set_requires_grad};
Py_RETURN_NONE;
{return_stmt}
"""
else:
typename = structseq_typenames.get(gen_structseq_typename_key(f))