onnxruntime/orttraining/orttraining/python/training
pengwa 5eda79bdd3
Improve perf for stage3 training (#18099)
### Improve perf for stage3 training - first wave

Port the existing PythonOp/PythonOpGrad Python runner to C++, and introduce
an unsafe run mode that skips on-the-fly detection of in-place updates,
tensors saved for backward, and materialized gradients.
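The following toy sketch illustrates the general idea of skipping on-the-fly detection. It is not ONNX Runtime's runner, which is written in C++ and works on tensors. Every name here (`RunnerCache`, `detect_inplace`, `run_forward`, `scale_inplace`) is hypothetical. Only in-place detection is modeled, using plain Python object identity. The sketch also assumes that in unsafe mode the facts are already known ahead of time; here they are recorded by one earlier safe call.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Sequence, Tuple


@dataclass
class RunnerCache:
    """Hypothetical per-op cache of facts that safe mode detects on every call."""
    inplace_map: Optional[List[Optional[int]]] = None  # output index -> aliased input index
    detections: int = 0  # how many times the (costly) detection actually ran


def detect_inplace(inputs: Sequence[object], outputs: Sequence[object]) -> List[Optional[int]]:
    """For each output, return the index of the input it aliases, or None."""
    mapping: List[Optional[int]] = []
    for out in outputs:
        # Identity check: an output that *is* an input was updated in place.
        mapping.append(next((i for i, inp in enumerate(inputs) if out is inp), None))
    return mapping


def run_forward(
    fn: Callable[..., object],
    inputs: Sequence[object],
    cache: RunnerCache,
    unsafe: bool,
) -> Tuple[Tuple[object, ...], List[Optional[int]]]:
    outputs = fn(*inputs)
    if not isinstance(outputs, tuple):
        outputs = (outputs,)
    if unsafe and cache.inplace_map is not None:
        # Unsafe mode: trust the previously recorded map; no per-call detection.
        return outputs, cache.inplace_map
    # Safe mode (or no recorded map yet): detect on the fly and record the result.
    cache.inplace_map = detect_inplace(inputs, outputs)
    cache.detections += 1
    return outputs, cache.inplace_map


def scale_inplace(buf: List[float], k: float) -> List[float]:
    """Toy op that mutates its first argument and returns it."""
    for i in range(len(buf)):
        buf[i] *= k
    return buf
```

In this sketch, a safe call followed by an unsafe call runs detection only once. The trade-off is the one the name suggests: if an op's aliasing behavior changed between calls, unsafe mode would silently reuse a stale answer.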

This reduces the overhead from tens to hundreds of microseconds to between a
few microseconds and the low tens of microseconds. In LLAMA2 7B training on
8x V100 32GB GPUs, we observed a 6.7% throughput gain over PyTorch (1.59 vs.
1.49 it/s).

Peak memory also dropped from 31GB to 28GB.
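The throughput gain is relative to the PyTorch run as the baseline. The memory drop can be expressed the same way. Note that the memory percentage is derived here from the two stated values and is not reported in the original.

```math
\frac{1.59 - 1.49}{1.49} = \frac{0.10}{1.49} \approx 0.067 = 6.7\%,
\qquad
\frac{31 - 28}{31} = \frac{3}{31} \approx 0.097 = 9.7\%
```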

2023-12-15 13:32:19 +08:00
| Name | Last commit | Date |
|---|---|---|
| `amp` | [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) | 2023-07-21 12:53:41 -07:00 |
| `api` | [On-Device Training] Expose Parameters through the Training API (#17364) | 2023-09-25 20:03:24 -07:00 |
| `experimental` | Manage ORTModule configurations consistently (#16396) | 2023-06-27 19:19:36 +08:00 |
| `onnxblock` | [Linter] Bump ruff and remove pylint (#17797) | 2023-10-05 21:07:33 -07:00 |
| `optim` | FP16 optimizer automatically detect DeepSpeed compatibility (#18084) | 2023-10-25 15:11:02 +08:00 |
| `ort_triton` | [ORTModule] Remove Unused Arguments from Generated Triton Code (#18636) | 2023-11-30 18:32:36 +08:00 |
| `ortmodule` | Improve perf for stage3 training (#18099) | 2023-12-15 13:32:19 +08:00 |
| `torchdynamo` | [Dort] Run type promotion pass to resolve dtype discrepancy (#18516) | 2023-12-01 09:36:18 -08:00 |
| `utils` | Improve perf for stage3 training (#18099) | 2023-12-15 13:32:19 +08:00 |
| `__init__.py` | Removed all the deprecated python training code and related tests and utils (#18333) | 2023-11-17 18:19:21 -08:00 |
| `_utils.py` | Removed all the deprecated python training code and related tests and utils (#18333) | 2023-11-17 18:19:21 -08:00 |
| `artifacts.py` | Fix opset version of the optimizer in function generate_artifacts (#18300) | 2023-11-22 09:15:11 -08:00 |