onnxruntime/orttraining/orttraining/python/training/utils
pengwa 6b7bce5ec9
Model post process for zero stage3 training (#17187)
### Model post process for zero stage3 training

This is the last change needed to make both single-GPU and multi-GPU runs pass.

Design details:
https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9

To run `PyTorch` with `ZeROOffloadSubscriber`:

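```python
from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3

model = prepare_model(...)  # build the model as usual (prepare_model is a placeholder)
# Switch plain PyTorch onto the ORT-compatible ZeRO stage 3 offload path.
configure_ort_compatible_zero_stage3()
```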
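As a rough mental model of the offload path (a simplification, not ORT's or DeepSpeed's code): under ZeRO stage 3 each parameter is partitioned across ranks, gathered in full just before the forward pass of the module that owns it, and released right after. The sketch below simulates that gather/release cycle for the forward pass only, in a single process; `ToyPartitionedParam`, `ToyOffloadSubscriber`, and `dot_layer` are hypothetical names, and the "all-gather" is plain list concatenation.

```python
from typing import List, Optional, Sequence


class ToyPartitionedParam:
    """Hypothetical stand-in for a ZeRO stage 3 parameter split across ranks."""

    def __init__(self, shards: List[List[float]]) -> None:
        self.shards = shards  # one shard per simulated rank
        self.data: Optional[List[float]] = None  # full value exists only while gathered

    def gather(self) -> None:
        # Stand-in for an all-gather: concatenate every rank's shard.
        self.data = [x for shard in self.shards for x in shard]

    def release(self) -> None:
        # Drop the full copy so only the shards stay resident.
        self.data = None


class ToyOffloadSubscriber:
    """Hypothetical subscriber: gather before forward, release after forward."""

    def pre_forward(self, params: Sequence[ToyPartitionedParam]) -> None:
        for param in params:
            param.gather()

    def post_forward(self, params: Sequence[ToyPartitionedParam]) -> None:
        for param in params:
            param.release()


def dot_layer(
    weight: ToyPartitionedParam,
    inputs: Sequence[float],
    subscriber: ToyOffloadSubscriber,
) -> float:
    """A one-weight 'layer' whose forward pass is bracketed by the subscriber."""
    subscriber.pre_forward([weight])
    try:
        if weight.data is None:
            raise RuntimeError("weight was not gathered before forward")
        return sum(w * x for w, x in zip(weight.data, inputs))
    finally:
        # Runs after the return value is computed, and also if forward raises.
        subscriber.post_forward([weight])


weight = ToyPartitionedParam([[1.0, 2.0], [3.0, 4.0]])  # two simulated ranks
print(dot_layer(weight, [1.0, 1.0, 1.0, 1.0], ToyOffloadSubscriber()))  # 10.0
print(weight.data)  # None: the full weight was released after forward
```

The `try`/`finally` is a choice made in this sketch so the full copy is dropped even when the forward pass fails.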

To run `ORTModule` with `ZeROOffloadSubscriber`:

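```python
import os

# Opt in to the ZeRO stage 3 offload path for ORTModule.
os.environ["ORTMODULE_ENABLE_ZERO_STAGE3"] = "1"

from onnxruntime.training.ortmodule import ORTModule

model = ORTModule(model)  # `model` is the torch.nn.Module being trained
```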
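If one process needs to switch between the PyTorch and ORTModule runs (for example, a test that compares them), a small context manager can scope the variable and restore its previous value afterwards. `ort_zero_stage3_enabled` is a hypothetical helper, not part of ORT, and the sketch assumes ORT reads `ORTMODULE_ENABLE_ZERO_STAGE3` when `ORTModule` is constructed, so the model should be wrapped inside the `with` block.

```python
import os
from contextlib import contextmanager
from typing import Iterator, Optional

ZERO_STAGE3_FLAG = "ORTMODULE_ENABLE_ZERO_STAGE3"


@contextmanager
def ort_zero_stage3_enabled() -> Iterator[None]:
    """Hypothetical helper: set the ZeRO stage 3 flag, then restore the old value."""
    previous: Optional[str] = os.environ.get(ZERO_STAGE3_FLAG)
    os.environ[ZERO_STAGE3_FLAG] = "1"
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(ZERO_STAGE3_FLAG, None)  # flag was unset before
        else:
            os.environ[ZERO_STAGE3_FLAG] = previous  # restore the caller's value


with ort_zero_stage3_enabled():
    # Assumption: wrap the model here, e.g. model = ORTModule(model)
    print(os.environ[ZERO_STAGE3_FLAG])  # 1
```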

Having both ORT and PyTorch run the same offload path makes it fairly easy
to debug convergence issues.
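One way to use the shared path: train the same model on the same data once with PyTorch and once with ORTModule, log the per-step loss, and locate the first step where the two curves part ways. `first_divergence` is a hypothetical helper, its tolerances are arbitrary defaults, and the loss values below are made up for illustration, not measured results.

```python
import math
from typing import Optional, Sequence


def first_divergence(
    reference: Sequence[float],
    candidate: Sequence[float],
    rel_tol: float = 1e-3,
    abs_tol: float = 1e-5,
) -> Optional[int]:
    """Return the first step where the two loss curves disagree, or None."""
    if len(reference) != len(candidate):
        raise ValueError("loss curves must cover the same number of steps")
    for step, (ref, cand) in enumerate(zip(reference, candidate)):
        if not math.isclose(ref, cand, rel_tol=rel_tol, abs_tol=abs_tol):
            return step
    return None


# Illustrative (made-up) per-step losses from a PyTorch run and an ORTModule run.
pytorch_losses = [2.31, 2.10, 1.95, 1.80]
ort_losses = [2.31, 2.10, 1.95, 1.62]
print(first_divergence(pytorch_losses, ort_losses))  # 3
```

Knowing the first diverging step narrows the search to whatever happened between that step and the previous one.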

2023-09-22 08:54:25 +08:00

| Name | Last commit | Date |
| --- | --- | --- |
| `data` | Bump ruff in CI (#15533) | 2023-04-17 10:11:44 -07:00 |
| `hooks` | Model post process for zero stage3 training (#17187) | 2023-09-22 08:54:25 +08:00 |
| `__init__.py` | Model post process for zero stage3 training (#17187) | 2023-09-22 08:54:25 +08:00 |
| `torch_io_helper.py` | Refactor schema extraction and output unflattening (#16894) | 2023-08-04 13:58:21 +08:00 |
| `torch_type_map.py` | Model post process for zero stage3 training (#17187) | 2023-09-22 08:54:25 +08:00 |