onnxruntime/orttraining/orttraining/python/training/ortmodule
pengwa 1150b1f81e
ORTModule memory improvement (#18924)
## Dependency

https://github.com/microsoft/onnxruntime/pull/19007

## ORTModule memory efficient gradient management

I previously tried to solve the coarse-grained gradient
accumulation/update problem in ORTModule with
https://github.com/microsoft/onnxruntime/pull/8979, but that
solution was not fully validated with DDP, or for the case where user
hooks are registered on gradient accumulation for torch parameters.

This PR addresses the problem with an approach similar to PR 8979,
i.e. it triggers gradient accumulation as soon as ORT has computed the
grad. Instead of using an AccumulateGrad op, this time it uses the ONNX
operator PythonOp, which internally calls param.backward(grad) so that
all related hooks are handled correctly.
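
The core idea can be illustrated in plain PyTorch. This is a minimal sketch, not the PR's actual PythonOp implementation: `accumulate_eagerly` and `record_grad` are hypothetical names, and the gradient tensors stand in for values that ORT would have computed. It only shows that calling `backward` on a leaf parameter with an externally computed gradient runs the hooks registered on that parameter and accumulates into `param.grad`.

```python
import torch

# Hypothetical stand-in for a parameter owned by the wrapped torch module.
param = torch.nn.Parameter(torch.zeros(3))

hook_calls = []


def record_grad(grad):
    # Stand-in for a user (or DDP/DeepSpeed-style) hook; it only records
    # the gradient it receives and leaves it unchanged.
    hook_calls.append(grad.clone())


param.register_hook(record_grad)


def accumulate_eagerly(p, grad):
    """Hypothetical helper: hand a gradient computed elsewhere to autograd.

    Calling backward on the leaf parameter itself runs its hooks and
    accumulates `grad` into `p.grad`.
    """
    p.backward(grad)


# Two "micro-steps", each delivering a gradient as soon as it is available.
accumulate_eagerly(param, torch.tensor([1.0, 2.0, 3.0]))
accumulate_eagerly(param, torch.tensor([0.5, 0.5, 0.5]))

print(param.grad)       # sum of both gradients
print(len(hook_calls))  # the hook ran once per call
```

Because each gradient is handed over as soon as it exists, it does not need to be kept alive until every other gradient of the backward pass has been produced.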


## Design

For details, see:

https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

## Convergence Validation


![image](https://github.com/microsoft/onnxruntime/assets/10530022/ccf3a213-e815-4b23-b759-165033b2d9fe)

The differences are mostly on the order of 0.000x and occasionally
0.00x. They may come from the gradients being applied in a different
order before and after this change (on DeepSpeed ZeRO stage 2).
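
One plausible reason why ordering alone can shift results slightly is that floating-point addition is not associative. This is an assumption about the mechanism, not something measured in this PR; the small sketch below only shows that summing the same values in a different order can change the last bits of the result.

```python
# The same three values, summed in two different orders.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

print(left_to_right == right_to_left)      # False
print(abs(left_to_right - right_to_left))  # a tiny, non-zero difference
```

In training, such rounding differences can be amplified over many steps, which would be consistent with small deviations in the loss curve.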


## TODO

Consolidate this logic with the similar Stage3 logic.
2024-01-16 08:57:37 +08:00
..
experimental Disable PERF* rules in ruff to allow better readability (#16834) 2023-07-25 15:38:22 -07:00
graph_optimizers [ORTModule] Adjust Attention Patterns for Efficient Attention ATen Fallback (#18471) 2023-11-22 15:24:05 +08:00
torch_cpp_extensions Minor fixes (#18949) 2023-12-28 20:01:06 +08:00
__init__.py [ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959) 2023-10-27 10:29:27 +08:00
_custom_autograd_function.py Improve perf for stage3 training (#18099) 2023-12-15 13:32:19 +08:00
_custom_autograd_function_exporter.py Improve perf for stage3 training (#18099) 2023-12-15 13:32:19 +08:00
_custom_gradient_registry.py ResizeGrad CUDA/ROCM kernel implementation (#17772) 2023-10-20 11:39:57 -07:00
_custom_op_symbolic_registry.py Optimize 4bit Qlora training (#18131) 2023-11-02 09:46:11 -07:00
_execution_agent.py Memory optimization refactor and refinement (#17481) 2023-11-23 11:39:00 +08:00
_fallback.py Bump linter versions (#18341) 2023-11-08 13:04:40 -08:00
_fallback_exceptions.py
_gradient_accumulation_manager.py
_graph_execution_interface.py
_graph_execution_manager.py ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
_graph_execution_manager_factory.py ORTModule log clean up (#16795) 2023-07-26 12:42:50 +08:00
_inference_manager.py Tune ORTModule logging experience a bit (#18298) 2023-11-08 17:42:50 +08:00
_io.py Skip module clone for preparing large model export (#18663) 2023-12-05 12:41:17 -08:00
_logger.py Memory optimization refactor and refinement (#17481) 2023-11-23 11:39:00 +08:00
_mem_efficient_grad_mgmt.py ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
_onnx_models.py ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
_pythonop_helper.py ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
_runtime_inspector.py Allow layer-wise recompute (#18566) 2023-12-12 08:44:05 +08:00
_torch_module_factory.py
_torch_module_interface.py
_torch_module_ort.py ORTModule log clean up (#16795) 2023-07-26 12:42:50 +08:00
_torch_module_pytorch.py
_training_manager.py ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
_utils.py Use full qualified name for PythonOp export (#17021) 2023-08-09 10:58:33 +08:00
_zero_stage3_compatibility.py Improve perf for stage3 training (#18099) 2023-12-15 13:32:19 +08:00
graph_optimizer_registry.py [ORTModule] ATen Efficient Attention and Triton Flash Attention (#17959) 2023-10-27 10:29:27 +08:00
options.py ORTModule memory improvement (#18924) 2024-01-16 08:57:37 +08:00
ortmodule.py ORTModule log clean up (#16795) 2023-07-26 12:42:50 +08:00