mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-22 22:01:08 +00:00
### Model post-process for ZeRO stage 3 training

This is the last change needed to make both single-GPU and multi-GPU runs pass.

Design details: https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9

Running `PyTorch` with `ZeROOffloadSubscriber`:

```
from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
model = prepare_model(...)  # user-defined model setup; arguments elided
# Make the ZeRO stage 3 offload hooks compatible with ORT.
configure_ort_compatible_zero_stage3()
```

Running `ORTModule` with `ZeROOffloadSubscriber`:

```
import os
# Enable ZeRO stage 3 support in ORTModule before wrapping the model.
os.environ['ORTMODULE_ENABLE_ZERO_STAGE3'] = '1'
from onnxruntime.training.ortmodule import ORTModule
model = ORTModule(model)  # `model` is the torch.nn.Module to wrap
```

Because ORT and PyTorch can now run the same offload path, convergence issues should be fairly easy to debug by comparing the two.
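One way to act on the PyTorch and ORTModule runs sharing the same offload path is to log per-step losses from each and find where they first disagree. The block below is a minimal sketch under that assumption: `first_divergent_step` is a hypothetical helper, not an onnxruntime API, the tolerances are arbitrary defaults, and the loss values in the usage lines are made up for illustration.

```python
import math
from typing import Optional, Sequence


def first_divergent_step(
    reference_losses: Sequence[float],
    candidate_losses: Sequence[float],
    rel_tol: float = 1e-3,
    abs_tol: float = 1e-5,
) -> Optional[int]:
    """Return the index of the first step where two loss curves disagree.

    Hypothetical debugging helper (not part of onnxruntime): compares per-step
    losses logged from a PyTorch run and an ORTModule run that use the same
    ZeRO stage 3 offload path. Returns None if every step agrees.
    """
    if len(reference_losses) != len(candidate_losses):
        raise ValueError("loss sequences must have the same length")
    for step, (ref, cand) in enumerate(zip(reference_losses, candidate_losses)):
        # math.isclose never treats NaN as close, so a NaN loss is flagged here.
        if not math.isclose(ref, cand, rel_tol=rel_tol, abs_tol=abs_tol):
            return step
    return None


if __name__ == "__main__":
    # Illustrative values only, not real measurements.
    torch_losses = [2.31, 2.05, 1.87, 1.70]
    ort_losses = [2.31, 2.05, 1.92, 1.80]
    print(first_divergent_step(torch_losses, ort_losses))  # prints 2
```

Pinpointing the first divergent step narrows the search to the forward/backward pass and parameter offload events of that single iteration, which is where inspecting the shared offload path pays off.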
| Name |
|---|
| `data` |
| `hooks` |
| `__init__.py` |
| `torch_io_helper.py` |
| `torch_type_map.py` |