onnxruntime/orttraining/orttraining/python/training
pengwa 1150b1f81e
ORTModule memory improvement (#18924)
## Dependency

https://github.com/microsoft/onnxruntime/pull/19007

## ORTModule memory efficient gradient management

Previously I tried to solve the coarse-grained gradient
accumulation/update problem in ORTModule with
https://github.com/microsoft/onnxruntime/pull/8979, but that
solution was not fully validated with DDP, or for cases where users
register hooks on the gradient accumulation of a torch parameter.

This PR addresses the problem with an approach similar to PR 8979,
i.e. it triggers gradient accumulation as soon as ORT has computed the
grad. Instead of using an AccumulateGrad op, this time it uses the ONNX
operator PythonOp, which internally calls param.backward(grad) so that
all related hooks are handled correctly.
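
As a rough illustration of why calling `param.backward(grad)` is hook-friendly, the following minimal sketch uses plain PyTorch only; it is not the PythonOp implementation from this PR. The name `grad_from_ort` is a hypothetical stand-in for a gradient that ORT would have computed.

```python
import torch

# A leaf parameter, as it would appear in a torch.nn.Module.
param = torch.nn.Parameter(torch.zeros(3))

# A user hook on the parameter's gradient; records every gradient it sees.
seen = []
param.register_hook(lambda g: seen.append(g.clone()))

# Hypothetical stand-in for a gradient produced by ORT (assumption).
grad_from_ort = torch.tensor([1.0, 2.0, 3.0])

# Feeding the gradient through autograd accumulates it into param.grad
# and fires the registered hook, instead of writing param.grad directly.
param.backward(grad_from_ort)
param.backward(grad_from_ort)  # a second micro-batch accumulates on top

print(param.grad)  # accumulated gradient
print(len(seen))   # number of times the hook fired
```

Because the gradient goes through the parameter's own autograd accumulation path, hooks registered by users or by wrappers are invoked per parameter, which a direct assignment to `param.grad` would bypass.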


## Design

For details, see:


https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

## Convergence Validation:


![image](https://github.com/microsoft/onnxruntime/assets/10530022/ccf3a213-e815-4b23-b759-165033b2d9fe)

Differences are mostly on the order of 0.000x, sometimes 0.00x, which
may come from gradients being applied in a different order before and
after this change (on DeepSpeed ZeRO stage 2).


## TODO

Consolidate this logic with the similar logic used for Stage 3.
2024-01-16 08:57:37 +08:00
..
amp [Better Engineering] Bump ruff to 0.0.278 and fix new lint errors (#16789) 2023-07-21 12:53:41 -07:00
api [On-Device Training] Expose Parameters through the Training API (#17364) 2023-09-25 20:03:24 -07:00
experimental Manage ORTModule configurations consistently (#16396) 2023-06-27 19:19:36 +08:00
onnxblock Offline tooling for training to use reduction with keepdims=False (#19027) 2024-01-11 10:51:23 -08:00
optim FP16 optimizer automatically detect DeepSpeed compatibility (#18084) 2023-10-25 15:11:02 +08:00
ort_triton [ORTModule] Remove Unused Arguments from Generated Triton Code (#18636) 2023-11-30 18:32:36 +08:00
ortmodule ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
utils ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
__init__.py Removed all the deprecated python training code and related tests and utils (#18333) 2023-11-17 18:19:21 -08:00
_utils.py Removed all the deprecated python training code and related tests and utils (#18333) 2023-11-17 18:19:21 -08:00
artifacts.py Fix opset version of the optimizer in function generate_artifacts (#18300) 2023-11-22 09:15:11 -08:00