mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
### Improve perf for stage3 training - first wave

Port the existing PythonOp/PythonOpGrad Python runner to C++, and introduce an unsafe run mode that skips the on-the-fly detection of in-place operations, save-for-backward tensors, and materialized gradients. This reduces the overhead from XX~XXX us to between X us and the lower end of XX us.

In LLAMA2 7B training on 8x 32GB V100 GPUs, we observed a 6.7% throughput gain over PyTorch (1.59 vs. 1.49 it/s). Peak memory also dropped from 31GB to 28GB.

| | | |
|---|---|---|
| orttraining | ||
| tools | ||