onnxruntime/orttraining/orttraining/python/training/ortmodule
pengwa 6b7bce5ec9
Model post process for zero stage3 training (#17187)
### Model post process for zero stage3 training

This is the last change needed for single-GPU and multi-GPU runs to pass.

Design details:
https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9

`PyTorch` runs with ZeROOffloadSubscriber:

```
  model = prepare_model(...)
  from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
  configure_ort_compatible_zero_stage3()
```

`ORTModule` runs with ZeROOffloadSubscriber:

```
  import os
  # Set the flag first, then import and wrap the model.
  os.environ['ORTMODULE_ENABLE_ZERO_STAGE3'] = '1'
  from onnxruntime.training.ortmodule import ORTModule
  model = ORTModule(self.model)
```

Convergence issues become much easier to debug when both ORT and
PyTorch can run the same offload path.
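
One way to use the shared offload path for debugging is to record the per-step loss from a PyTorch run and from an ORTModule run, then look for the first step where they drift apart. The sketch below is only an illustration: `first_divergent_step` and the tolerance are hypothetical and not part of onnxruntime, and it assumes the losses were already collected as plain floats.

```python
import math
from typing import Optional, Sequence


def first_divergent_step(
    ref_losses: Sequence[float],
    test_losses: Sequence[float],
    rel_tol: float = 1e-3,
) -> Optional[int]:
    """Return the index of the first step whose losses differ by more
    than rel_tol (relative), or None if all compared steps agree.

    Hypothetical helper: ref_losses would come from the PyTorch run and
    test_losses from the ORTModule run, both using the same offload path.
    Only the steps present in both sequences are compared.
    """
    for step, (ref, test) in enumerate(zip(ref_losses, test_losses)):
        if not math.isclose(ref, test, rel_tol=rel_tol):
            return step
    return None
```

Calling `first_divergent_step(pytorch_losses, ort_losses)` on the two recorded curves points at the training step where investigation should start; the right tolerance depends on the model and precision in use.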

2023-09-22 08:54:25 +08:00
| Name | Last commit | Date |
| --- | --- | --- |
| `..` | | |
| `experimental` | Disable PERF* rules in ruff to allow better readability (#16834) | 2023-07-25 15:38:22 -07:00 |
| `torch_cpp_extensions` | PythonOp Enhancement: Bool and Tuple[Bool] Constants, Materialize Grads, Empty Inputs, Save In Context (#16828) | 2023-08-15 13:31:04 +08:00 |
| `__init__.py` | Change RuntimeError to ImportError (#17380) | 2023-09-01 09:56:40 +08:00 |
| `_custom_autograd_function.py` | Manage ORTModule configurations consistently (#16396) | 2023-06-27 19:19:36 +08:00 |
| `_custom_autograd_function_exporter.py` | Model post process for zero stage3 training (#17187) | 2023-09-22 08:54:25 +08:00 |
| `_custom_autograd_function_runner.py` | Model post process for zero stage3 training (#17187) | 2023-09-22 08:54:25 +08:00 |
| `_custom_gradient_registry.py` | [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) | 2023-07-21 12:53:41 -07:00 |
| `_custom_op_symbolic_registry.py` | Introduce ZeROOffloadSubscriber for ORTModule (#17006) | 2023-08-25 00:15:22 +08:00 |
| `_execution_agent.py` | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| `_fallback.py` | Consolidate ORTModule logging (#16078) | 2023-06-01 10:09:12 +08:00 |
| `_fallback_exceptions.py` | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| `_gradient_accumulation_manager.py` | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| `_graph_execution_interface.py` | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| `_graph_execution_manager.py` | Model post process for zero stage3 training (#17187) | 2023-09-22 08:54:25 +08:00 |
| `_graph_execution_manager_factory.py` | ORTModule log clean up (#16795) | 2023-07-26 12:42:50 +08:00 |
| `_inference_manager.py` | Model post process for zero stage3 training (#17187) | 2023-09-22 08:54:25 +08:00 |
| `_io.py` | Model post process for zero stage3 training (#17187) | 2023-09-22 08:54:25 +08:00 |
| `_logger.py` | Fix few small bugs (#17019) | 2023-08-07 14:01:36 +08:00 |
| `_onnx_models.py` | Save optimized pre_grad graph once ready (#16816) | 2023-08-02 14:05:26 +08:00 |
| `_runtime_inspector.py` | ORTModule log clean up (#16795) | 2023-07-26 12:42:50 +08:00 |
| `_torch_module_factory.py` | Manage ORTModule configurations consistently (#16396) | 2023-06-27 19:19:36 +08:00 |
| `_torch_module_interface.py` | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| `_torch_module_ort.py` | ORTModule log clean up (#16795) | 2023-07-26 12:42:50 +08:00 |
| `_torch_module_pytorch.py` | Adopt linrtunner as the linting tool - take 2 (#15085) | 2023-03-24 15:29:03 -07:00 |
| `_training_manager.py` | Model post process for zero stage3 training (#17187) | 2023-09-22 08:54:25 +08:00 |
| `_utils.py` | Use full qualified name for PythonOp export (#17021) | 2023-08-09 10:58:33 +08:00 |
| `_zero_stage3_compatibility.py` | Model post process for zero stage3 training (#17187) | 2023-09-22 08:54:25 +08:00 |
| `graph_transformer_registry.py` | [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) | 2023-07-21 12:53:41 -07:00 |
| `options.py` | Introduce ZeROOffloadSubscriber for ORTModule (#17006) | 2023-08-25 00:15:22 +08:00 |
| `ortmodule.py` | ORTModule log clean up (#16795) | 2023-07-26 12:42:50 +08:00 |