onnxruntime/orttraining
pengwa 7bec80d92a
Fix reference count for autograd.Function (#15121)
### Fix reference count for autograd

When the PythonOp kernel is initialized, `AddPointerScalarArgs` creates
`const_args_` and stores in it every non-tensor reference (including
ProcessGroup, string, or other user types).

In the kernel's destructor, the reference count of every entry in `const_args_` is decremented:


```cpp
void PythonOpBase::Clear() {
  for (auto ptr : const_args_) {
    auto obj = reinterpret_cast<PyObject*>(ptr);
    Py_DECREF(obj);
  }
}
```

This means the reference count was never incremented, only decremented.
As a result, running the unit test triggers a segmentation fault. The
simple fix is to remove the `Py_DECREF` that the kernel destructor
triggers for those pointer-type constant inputs.
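
The imbalance described above can be illustrated with a toy model. This is a minimal sketch in plain Python, not the actual C++ kernel code: `CountedObject`, `BuggyKernel`, and `FixedKernel` are hypothetical names introduced only for illustration, and the integer counter merely imitates what `Py_INCREF`/`Py_DECREF` do to a real `PyObject`.

```python
class CountedObject:
    """Toy model of a reference-counted object (not a real PyObject)."""

    def __init__(self, name):
        self.name = name
        self.refcount = 1  # the single reference held by the creator
        self.freed = False

    def incref(self):
        self.refcount += 1

    def decref(self):
        self.refcount -= 1
        if self.refcount == 0:
            self.freed = True  # a real object would be deallocated here


class BuggyKernel:
    """Keeps borrowed pointers but decrements them on teardown."""

    def __init__(self, const_args):
        self.const_args = list(const_args)  # no incref: pointers are borrowed

    def clear(self):
        for obj in self.const_args:
            obj.decref()  # unbalanced: there was no matching incref


class FixedKernel:
    """Keeps borrowed pointers and leaves their counts untouched."""

    def __init__(self, const_args):
        self.const_args = list(const_args)

    def clear(self):
        self.const_args.clear()  # forget the pointers, do not decref


group = CountedObject("process_group")
BuggyKernel([group]).clear()
buggy_freed = group.freed  # the creator's reference now dangles

group2 = CountedObject("process_group")
FixedKernel([group2]).clear()
fixed_freed = group2.freed  # the creator's reference is still valid
```

In the buggy variant the object is released while its creator still holds a pointer to it, which in the real kernel is the kind of use-after-free that can surface as a segmentation fault.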

NONTENSOR_OBJECT_POINTER_STORE is where the reference count is
incremented during export; the reference then remains until the Python
program terminates.
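
The idea of a store that owns one reference per exported object can be sketched as follows. This is an assumption-laden illustration, not the actual NONTENSOR_OBJECT_POINTER_STORE implementation: `PointerStore`, `register`, and `ProcessGroupStub` are hypothetical names, and the sketch relies on the CPython detail that `id(obj)` is the object's address.

```python
import sys
import weakref


class ProcessGroupStub:
    """Hypothetical stand-in for a non-tensor argument such as a ProcessGroup."""


class PointerStore:
    """Hypothetical registry that owns one reference per registered object."""

    def __init__(self):
        self._objects = {}

    def register(self, obj):
        # Keeping the object in the dict takes a strong reference, so the
        # address returned here stays valid for as long as the entry exists.
        address = id(obj)  # CPython detail: id() is the object's address
        self._objects[address] = obj
        return address

    def clear(self):
        # Dropping the entries is the matching decrement.
        self._objects.clear()


store = PointerStore()
group = ProcessGroupStub()

before = sys.getrefcount(group)
address = store.register(group)
after = sys.getrefcount(group)  # one more reference: the store's

alive = weakref.ref(group)
del group  # the caller's own reference goes away
still_alive = alive() is not None  # the store's reference keeps it alive

store.clear()
freed = alive() is None  # CPython frees the object once the last reference is gone
```

Because the store, not the kernel, owns the extra reference, a consumer that only holds the raw address never needs to decrement anything.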


Additional tunings:
1. Move some logs from warning to verbose level to avoid flooding the
training logs.
2. Move pointer-type reference holding from the Python side
(NONTENSOR_OBJECT_POINTER_STORE) to
orttraining/orttraining/core/framework/torch/custom_function_register.h.
This gives a consistent approach to incrementing and decrementing the
reference counts of all PythonOp-related Python objects and methods.
2023-03-23 12:51:50 +08:00
..
orttraining Fix reference count for autograd.Function (#15121) 2023-03-23 12:51:50 +08:00
pytorch_frontend_examples Set black's target version (#11370) 2022-04-27 14:52:19 -07:00
tools Refactor training build options (#13964) 2023-01-03 13:28:16 -08:00