onnxruntime/orttraining/orttraining/python
pengwa 5eda79bdd3
Improve perf for stage3 training (#18099)
### Improve perf for stage3 training - first wave

Port the existing PythonOp/PythonOpGrad Python runner to C++, and
introduce an unsafe run mode (to skip on-the-fly detection of inplace,
save-for-backward, and materialized-grad behavior).

This reduces the overhead from XX~XXX us to X ~ lower end of XX us. In
LLAMA2 7B training on 8x 32GB V100, we have observed a 6.7% gain over
PyTorch (1.59 vs. 1.49 it/s).

Peak memory also dropped from 31GB to 28GB.
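
As a quick sanity check on the figures quoted above, the relative changes can be recomputed from the raw numbers. This is only an illustrative sketch: the helper name `relative_change` is hypothetical and not part of ONNX Runtime, and the inputs are just the throughput and peak-memory values stated in this message.

```python
def relative_change(new: float, old: float) -> float:
    """Return the relative change of `new` over `old`, in percent."""
    return (new - old) / old * 100.0


# Throughput in it/s as quoted: 1.59 (this change) vs. 1.49 (PyTorch).
throughput_gain = relative_change(1.59, 1.49)

# Peak memory in GB as quoted: 28 (after) vs. 31 (before).
memory_change = relative_change(28.0, 31.0)

print(f"throughput:  {throughput_gain:+.1f}%")
print(f"peak memory: {memory_change:+.1f}%")
```

The throughput figure rounds to the 6.7% stated above, and the peak-memory drop works out to a little under 10%.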

2023-12-15 13:32:19 +08:00
..
training Improve perf for stage3 training (#18099) 2023-12-15 13:32:19 +08:00
orttraining_pybind_common.h Re-work global objects dependancies in pybind layer. (#14941) 2023-03-10 13:55:31 -08:00
orttraining_pybind_state.cc Improve perf for stage3 training (#18099) 2023-12-15 13:32:19 +08:00
orttraining_python_module.cc Fix warning C4003 in ORT python binding code (#18612) 2023-11-30 08:07:47 -08:00
orttraining_python_module_eager.h Run clang-format in CI (#15524) 2023-04-18 09:26:58 -07:00
pt_patch.py Adopt linrtunner as the linting tool - take 2 (#15085) 2023-03-24 15:29:03 -07:00