onnxruntime/orttraining
pengwa abf9765d73
PythonOp Enhancement: Bool and Tuple[Bool] Constants, Materialize Grads, Empty Inputs, Save In Context (#16828)
### PythonOp Enhancement: Bool and Tuple[Bool] Constants, Materialize Grads, Empty Inputs, Save In Context

1. Support `bool` and `Tuple[bool]` constants as inputs.

2. Support `ctx.set_materialize_grads(True|False)`.

3. The backward op can accept empty inputs (for inputs that don't require grad).

4. Special handling for ORT tensors that are saved in context
**Scenario**: a tensor generated by ORT may be saved for backward via
`ctx.save_for_backward(tensor)`. Because ORT's allocation plan does not
increase `tensor`'s reference count, ORT may release the tensor's data
before backward uses it.
**Currently**: we copy every tensor before running
`autograd.Function.forward()`, which can be costly when there are many
PythonOps (for example, with ZeRO stage 3).
**Proposal**: to avoid unnecessary copies of tensors that are not saved
in context, this change introduces a `_GlobalOpKernelInfoMap`. On the
kernel's first run, we still copy all tensors generated by ORT and pass
them to the `torch.autograd.Function` to run. We then check which inputs
were saved in context and record their indices in
`_GlobalOpKernelInfoMap`. In later iterations, we copy only those inputs.
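The copy-everything-once, then copy-only-what-was-saved scheme can be sketched in plain Python. This is a minimal sketch under stated assumptions, not ORT's implementation: `Tensor`, `SketchCtx`, `KernelInfo`, `GlobalKernelInfoMap`, and the string kernel ids are hypothetical stand-ins (no PyTorch dependency), "saved in context" is detected by object identity, and each kernel is assumed to save the same input indices on every run.

```python
import copy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple


@dataclass
class Tensor:
    """Hypothetical stand-in for an ORT-produced tensor."""
    data: List[float]


class SketchCtx:
    """Hypothetical stand-in for the autograd ctx; only models save_for_backward."""

    def __init__(self) -> None:
        self.saved: Tuple[Any, ...] = ()

    def save_for_backward(self, *tensors: Any) -> None:
        self.saved = tensors


@dataclass
class KernelInfo:
    # Indices of inputs saved in ctx; None until the kernel's first run completes.
    saved_input_indices: Optional[Set[int]] = None


class GlobalKernelInfoMap:
    """Hypothetical per-kernel cache, loosely modeled on the _GlobalOpKernelInfoMap idea."""

    def __init__(self) -> None:
        self._infos: Dict[str, KernelInfo] = {}

    def saved_indices(self, kernel_id: str) -> Optional[Set[int]]:
        info = self._infos.get(kernel_id)
        return None if info is None else info.saved_input_indices

    def run_forward(
        self, kernel_id: str, forward: Callable[..., Any], inputs: Sequence[Any]
    ) -> Tuple[Any, List[Any]]:
        info = self._infos.setdefault(kernel_id, KernelInfo())
        first_run = info.saved_input_indices is None
        # First run: copy every input, since we can't yet know what forward() saves.
        # Later runs: copy only the inputs that were saved in ctx on the first run.
        to_copy = set(range(len(inputs))) if first_run else info.saved_input_indices
        args = [copy.deepcopy(x) if i in to_copy else x for i, x in enumerate(inputs)]

        ctx = SketchCtx()
        out = forward(ctx, *args)

        if first_run:
            # Record which (copied) inputs forward() passed to save_for_backward.
            saved_ids = {id(t) for t in ctx.saved}
            info.saved_input_indices = {i for i, a in enumerate(args) if id(a) in saved_ids}
        return out, args


# Usage: only the second input is saved for backward.
def forward(ctx, a, b, c):
    ctx.save_for_backward(b)
    return a.data[0] + b.data[0] + c.data[0]


kernel_map = GlobalKernelInfoMap()
x, y, z = Tensor([1.0]), Tensor([2.0]), Tensor([3.0])
kernel_map.run_forward("python_op_0", forward, [x, y, z])  # copies x, y and z
print(kernel_map.saved_indices("python_op_0"))             # {1}
kernel_map.run_forward("python_op_0", forward, [x, y, z])  # copies only y
```

The first run pays for a full copy so that no prior knowledge of which inputs `forward()` saves is needed; if a kernel's saving pattern could change between iterations, the cached indices in this sketch would have to be invalidated.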
2023-08-15 13:31:04 +08:00
orttraining PythonOp Enhancement: Bool and Tuple[Bool] Constants, Materialize Grads, Empty Inputs, Save In Context (#16828) 2023-08-15 13:31:04 +08:00
pytorch_frontend_examples [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) 2023-07-21 12:53:41 -07:00
tools [ROCm] Update CI based on ubuntu 22.04 (#17076) 2023-08-10 09:51:29 -07:00